This view is limited to 50 files because it contains too many changes.  See the raw diff here.
Files changed (50) hide show
  1. app.py +2 -1
  2. lavis/__init__.py +31 -0
  3. lavis/common/config.py +468 -0
  4. lavis/common/dist_utils.py +137 -0
  5. lavis/common/gradcam.py +24 -0
  6. lavis/common/logger.py +195 -0
  7. lavis/common/optims.py +117 -0
  8. lavis/common/registry.py +329 -0
  9. lavis/common/utils.py +424 -0
  10. lavis/common/vqa_tools/__init__.py +8 -0
  11. lavis/common/vqa_tools/vqa.py +211 -0
  12. lavis/common/vqa_tools/vqa_eval.py +324 -0
  13. lavis/configs/datasets/aokvqa/defaults.yaml +35 -0
  14. lavis/configs/datasets/avsd/defaults_dial.yaml +24 -0
  15. lavis/configs/datasets/coco/defaults_cap.yaml +28 -0
  16. lavis/configs/datasets/coco/defaults_ret.yaml +27 -0
  17. lavis/configs/datasets/coco/defaults_vqa.yaml +41 -0
  18. lavis/configs/datasets/coco/eval_vqa.yaml +27 -0
  19. lavis/configs/datasets/conceptual_caption/defaults_12m.yaml +20 -0
  20. lavis/configs/datasets/conceptual_caption/defaults_3m.yaml +20 -0
  21. lavis/configs/datasets/didemo/defaults_ret.yaml +25 -0
  22. lavis/configs/datasets/flickr30k/defaults.yaml +24 -0
  23. lavis/configs/datasets/gqa/balanced_testdev.yaml +30 -0
  24. lavis/configs/datasets/gqa/balanced_val.yaml +30 -0
  25. lavis/configs/datasets/gqa/defaults.yaml +36 -0
  26. lavis/configs/datasets/imagenet/defaults.yaml +15 -0
  27. lavis/configs/datasets/laion/defaults_2B_multi.yaml +13 -0
  28. lavis/configs/datasets/msrvtt/defaults_cap.yaml +24 -0
  29. lavis/configs/datasets/msrvtt/defaults_qa.yaml +27 -0
  30. lavis/configs/datasets/msrvtt/defaults_ret.yaml +24 -0
  31. lavis/configs/datasets/msvd/defaults_cap.yaml +24 -0
  32. lavis/configs/datasets/msvd/defaults_qa.yaml +29 -0
  33. lavis/configs/datasets/nlvr/defaults.yaml +24 -0
  34. lavis/configs/datasets/nocaps/defaults.yaml +22 -0
  35. lavis/configs/datasets/okvqa/defaults.yaml +37 -0
  36. lavis/configs/datasets/sbu_caption/defaults.yaml +22 -0
  37. lavis/configs/datasets/snli_ve/defaults.yaml +25 -0
  38. lavis/configs/datasets/vatex/defaults_cap.yaml +24 -0
  39. lavis/configs/datasets/vg/defaults_caption.yaml +18 -0
  40. lavis/configs/datasets/vg/defaults_vqa.yaml +18 -0
  41. lavis/configs/default.yaml +10 -0
  42. lavis/configs/models/albef_classification_ve.yaml +40 -0
  43. lavis/configs/models/albef_feature_extractor.yaml +30 -0
  44. lavis/configs/models/albef_nlvr.yaml +42 -0
  45. lavis/configs/models/albef_pretrain_base.yaml +38 -0
  46. lavis/configs/models/albef_retrieval_coco.yaml +46 -0
  47. lavis/configs/models/albef_retrieval_flickr.yaml +46 -0
  48. lavis/configs/models/albef_vqav2.yaml +40 -0
  49. lavis/configs/models/alpro_qa_msrvtt.yaml +44 -0
  50. lavis/configs/models/alpro_qa_msvd.yaml +43 -0
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import os
2
  import subprocess
3
 
4
- subprocess.run("pip install salesforce-lavis --no-deps", shell=True)
 
5
 
6
  from PIL import Image
7
  import gradio as gr
 
1
  import os
2
  import subprocess
3
 
4
+ #subprocess.run("pip install salesforce-lavis --no-deps", shell=True)
5
+ # https://github.com/salesforce/BLIP/issues/165
6
 
7
  from PIL import Image
8
  import gradio as gr
lavis/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+ import sys
10
+
11
+ from omegaconf import OmegaConf
12
+
13
+ from lavis.common.registry import registry
14
+
15
+ from lavis.datasets.builders import *
16
+ from lavis.models import *
17
+ from lavis.processors import *
18
+ from lavis.tasks import *
19
+
20
+
21
+ root_dir = os.path.dirname(os.path.abspath(__file__))
22
+ default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
23
+
24
+ registry.register_path("library_root", root_dir)
25
+ repo_root = os.path.join(root_dir, "..")
26
+ registry.register_path("repo_root", repo_root)
27
+ cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
28
+ registry.register_path("cache_root", cache_root)
29
+
30
+ registry.register("MAX_INT", sys.maxsize)
31
+ registry.register("SPLIT_NAMES", ["train", "val", "test"])
lavis/common/config.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import json
10
+ from typing import Dict
11
+
12
+ from omegaconf import OmegaConf
13
+ from lavis.common.registry import registry
14
+
15
+
16
+ class Config:
17
+ def __init__(self, args):
18
+ self.config = {}
19
+
20
+ self.args = args
21
+
22
+ # Register the config and configuration for setup
23
+ registry.register("configuration", self)
24
+
25
+ user_config = self._build_opt_list(self.args.options)
26
+
27
+ config = OmegaConf.load(self.args.cfg_path)
28
+
29
+ runner_config = self.build_runner_config(config)
30
+ model_config = self.build_model_config(config, **user_config)
31
+ dataset_config = self.build_dataset_config(config)
32
+
33
+ # Validate the user-provided runner configuration
34
+ # model and dataset configuration are supposed to be validated by the respective classes
35
+ # [TODO] validate the model/dataset configuration
36
+ # self._validate_runner_config(runner_config)
37
+
38
+ # Override the default configuration with user options.
39
+ self.config = OmegaConf.merge(
40
+ runner_config, model_config, dataset_config, user_config
41
+ )
42
+
43
+ def _validate_runner_config(self, runner_config):
44
+ """
45
+ This method validates the configuration, such that
46
+ 1) all the user specified options are valid;
47
+ 2) no type mismatches between the user specified options and the config.
48
+ """
49
+ runner_config_validator = create_runner_config_validator()
50
+ runner_config_validator.validate(runner_config)
51
+
52
+ def _build_opt_list(self, opts):
53
+ opts_dot_list = self._convert_to_dot_list(opts)
54
+ return OmegaConf.from_dotlist(opts_dot_list)
55
+
56
+ @staticmethod
57
+ def build_model_config(config, **kwargs):
58
+ model = config.get("model", None)
59
+ assert model is not None, "Missing model configuration file."
60
+
61
+ model_cls = registry.get_model_class(model.arch)
62
+ assert model_cls is not None, f"Model '{model.arch}' has not been registered."
63
+
64
+ model_type = kwargs.get("model.model_type", None)
65
+ if not model_type:
66
+ model_type = model.get("model_type", None)
67
+ # else use the model type selected by user.
68
+
69
+ assert model_type is not None, "Missing model_type."
70
+
71
+ model_config_path = model_cls.default_config_path(model_type=model_type)
72
+
73
+ model_config = OmegaConf.create()
74
+ # hiararchy override, customized config > default config
75
+ model_config = OmegaConf.merge(
76
+ model_config,
77
+ OmegaConf.load(model_config_path),
78
+ {"model": config["model"]},
79
+ )
80
+
81
+ return model_config
82
+
83
+ @staticmethod
84
+ def build_runner_config(config):
85
+ return {"run": config.run}
86
+
87
+ @staticmethod
88
+ def build_dataset_config(config):
89
+ datasets = config.get("datasets", None)
90
+ if datasets is None:
91
+ raise KeyError(
92
+ "Expecting 'datasets' as the root key for dataset configuration."
93
+ )
94
+
95
+ dataset_config = OmegaConf.create()
96
+
97
+ for dataset_name in datasets:
98
+ builder_cls = registry.get_builder_class(dataset_name)
99
+
100
+ dataset_config_type = datasets[dataset_name].get("type", "default")
101
+ dataset_config_path = builder_cls.default_config_path(
102
+ type=dataset_config_type
103
+ )
104
+
105
+ # hiararchy override, customized config > default config
106
+ dataset_config = OmegaConf.merge(
107
+ dataset_config,
108
+ OmegaConf.load(dataset_config_path),
109
+ {"datasets": {dataset_name: config["datasets"][dataset_name]}},
110
+ )
111
+
112
+ return dataset_config
113
+
114
+ def _convert_to_dot_list(self, opts):
115
+ if opts is None:
116
+ opts = []
117
+
118
+ if len(opts) == 0:
119
+ return opts
120
+
121
+ has_equal = opts[0].find("=") != -1
122
+
123
+ if has_equal:
124
+ return opts
125
+
126
+ return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
127
+
128
+ def get_config(self):
129
+ return self.config
130
+
131
+ @property
132
+ def run_cfg(self):
133
+ return self.config.run
134
+
135
+ @property
136
+ def datasets_cfg(self):
137
+ return self.config.datasets
138
+
139
+ @property
140
+ def model_cfg(self):
141
+ return self.config.model
142
+
143
+ def pretty_print(self):
144
+ logging.info("\n===== Running Parameters =====")
145
+ logging.info(self._convert_node_to_json(self.config.run))
146
+
147
+ logging.info("\n====== Dataset Attributes ======")
148
+ datasets = self.config.datasets
149
+
150
+ for dataset in datasets:
151
+ if dataset in self.config.datasets:
152
+ logging.info(f"\n======== {dataset} =======")
153
+ dataset_config = self.config.datasets[dataset]
154
+ logging.info(self._convert_node_to_json(dataset_config))
155
+ else:
156
+ logging.warning(f"No dataset named '{dataset}' in config. Skipping")
157
+
158
+ logging.info(f"\n====== Model Attributes ======")
159
+ logging.info(self._convert_node_to_json(self.config.model))
160
+
161
+ def _convert_node_to_json(self, node):
162
+ container = OmegaConf.to_container(node, resolve=True)
163
+ return json.dumps(container, indent=4, sort_keys=True)
164
+
165
+ def to_dict(self):
166
+ return OmegaConf.to_container(self.config)
167
+
168
+
169
+ def node_to_dict(node):
170
+ return OmegaConf.to_container(node)
171
+
172
+
173
+ class ConfigValidator:
174
+ """
175
+ This is a preliminary implementation to centralize and validate the configuration.
176
+ May be altered in the future.
177
+
178
+ A helper class to validate configurations from yaml file.
179
+
180
+ This serves the following purposes:
181
+ 1. Ensure all the options in the yaml are defined, raise error if not.
182
+ 2. when type mismatches are found, the validator will raise an error.
183
+ 3. a central place to store and display helpful messages for supported configurations.
184
+
185
+ """
186
+
187
+ class _Argument:
188
+ def __init__(self, name, choices=None, type=None, help=None):
189
+ self.name = name
190
+ self.val = None
191
+ self.choices = choices
192
+ self.type = type
193
+ self.help = help
194
+
195
+ def __str__(self):
196
+ s = f"{self.name}={self.val}"
197
+ if self.type is not None:
198
+ s += f", ({self.type})"
199
+ if self.choices is not None:
200
+ s += f", choices: {self.choices}"
201
+ if self.help is not None:
202
+ s += f", ({self.help})"
203
+ return s
204
+
205
+ def __init__(self, description):
206
+ self.description = description
207
+
208
+ self.arguments = dict()
209
+
210
+ self.parsed_args = None
211
+
212
+ def __getitem__(self, key):
213
+ assert self.parsed_args is not None, "No arguments parsed yet."
214
+
215
+ return self.parsed_args[key]
216
+
217
+ def __str__(self) -> str:
218
+ return self.format_help()
219
+
220
+ def add_argument(self, *args, **kwargs):
221
+ """
222
+ Assume the first argument is the name of the argument.
223
+ """
224
+ self.arguments[args[0]] = self._Argument(*args, **kwargs)
225
+
226
+ def validate(self, config=None):
227
+ """
228
+ Convert yaml config (dict-like) to list, required by argparse.
229
+ """
230
+ for k, v in config.items():
231
+ assert (
232
+ k in self.arguments
233
+ ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
234
+
235
+ if self.arguments[k].type is not None:
236
+ try:
237
+ self.arguments[k].val = self.arguments[k].type(v)
238
+ except ValueError:
239
+ raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
240
+
241
+ if self.arguments[k].choices is not None:
242
+ assert (
243
+ v in self.arguments[k].choices
244
+ ), f"""{k} must be one of {self.arguments[k].choices}."""
245
+
246
+ return config
247
+
248
+ def format_arguments(self):
249
+ return str([f"{k}" for k in sorted(self.arguments.keys())])
250
+
251
+ def format_help(self):
252
+ # description + key-value pair string for each argument
253
+ help_msg = str(self.description)
254
+ return help_msg + ", available arguments: " + self.format_arguments()
255
+
256
+ def print_help(self):
257
+ # display help message
258
+ print(self.format_help())
259
+
260
+
261
+ def create_runner_config_validator():
262
+ validator = ConfigValidator(description="Runner configurations")
263
+
264
+ validator.add_argument(
265
+ "runner",
266
+ type=str,
267
+ choices=["runner_base", "runner_iter"],
268
+ help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
269
+ runner runs based on iters. Default: runner_base""",
270
+ )
271
+ # add argumetns for training dataset ratios
272
+ validator.add_argument(
273
+ "train_dataset_ratios",
274
+ type=Dict[str, float],
275
+ help="""Ratios of training dataset. This is used in iteration-based runner.
276
+ Do not support for epoch-based runner because how to define an epoch becomes tricky.
277
+ Default: None""",
278
+ )
279
+ validator.add_argument(
280
+ "max_iters",
281
+ type=float,
282
+ help="Maximum number of iterations to run.",
283
+ )
284
+ validator.add_argument(
285
+ "max_epoch",
286
+ type=int,
287
+ help="Maximum number of epochs to run.",
288
+ )
289
+ # add arguments for iters_per_inner_epoch
290
+ validator.add_argument(
291
+ "iters_per_inner_epoch",
292
+ type=float,
293
+ help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
294
+ )
295
+ lr_scheds_choices = registry.list_lr_schedulers()
296
+ validator.add_argument(
297
+ "lr_sched",
298
+ type=str,
299
+ choices=lr_scheds_choices,
300
+ help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
301
+ )
302
+ task_choices = registry.list_tasks()
303
+ validator.add_argument(
304
+ "task",
305
+ type=str,
306
+ choices=task_choices,
307
+ help="Task to use, from {}".format(task_choices),
308
+ )
309
+ # add arguments for init_lr
310
+ validator.add_argument(
311
+ "init_lr",
312
+ type=float,
313
+ help="Initial learning rate. This will be the learning rate after warmup and before decay.",
314
+ )
315
+ # add arguments for min_lr
316
+ validator.add_argument(
317
+ "min_lr",
318
+ type=float,
319
+ help="Minimum learning rate (after decay).",
320
+ )
321
+ # add arguments for warmup_lr
322
+ validator.add_argument(
323
+ "warmup_lr",
324
+ type=float,
325
+ help="Starting learning rate for warmup.",
326
+ )
327
+ # add arguments for learning rate decay rate
328
+ validator.add_argument(
329
+ "lr_decay_rate",
330
+ type=float,
331
+ help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
332
+ )
333
+ # add arguments for weight decay
334
+ validator.add_argument(
335
+ "weight_decay",
336
+ type=float,
337
+ help="Weight decay rate.",
338
+ )
339
+ # add arguments for training batch size
340
+ validator.add_argument(
341
+ "batch_size_train",
342
+ type=int,
343
+ help="Training batch size.",
344
+ )
345
+ # add arguments for evaluation batch size
346
+ validator.add_argument(
347
+ "batch_size_eval",
348
+ type=int,
349
+ help="Evaluation batch size, including validation and testing.",
350
+ )
351
+ # add arguments for number of workers for data loading
352
+ validator.add_argument(
353
+ "num_workers",
354
+ help="Number of workers for data loading.",
355
+ )
356
+ # add arguments for warm up steps
357
+ validator.add_argument(
358
+ "warmup_steps",
359
+ type=int,
360
+ help="Number of warmup steps. Required if a warmup schedule is used.",
361
+ )
362
+ # add arguments for random seed
363
+ validator.add_argument(
364
+ "seed",
365
+ type=int,
366
+ help="Random seed.",
367
+ )
368
+ # add arguments for output directory
369
+ validator.add_argument(
370
+ "output_dir",
371
+ type=str,
372
+ help="Output directory to save checkpoints and logs.",
373
+ )
374
+ # add arguments for whether only use evaluation
375
+ validator.add_argument(
376
+ "evaluate",
377
+ help="Whether to only evaluate the model. If true, training will not be performed.",
378
+ )
379
+ # add arguments for splits used for training, e.g. ["train", "val"]
380
+ validator.add_argument(
381
+ "train_splits",
382
+ type=list,
383
+ help="Splits to use for training.",
384
+ )
385
+ # add arguments for splits used for validation, e.g. ["val"]
386
+ validator.add_argument(
387
+ "valid_splits",
388
+ type=list,
389
+ help="Splits to use for validation. If not provided, will skip the validation.",
390
+ )
391
+ # add arguments for splits used for testing, e.g. ["test"]
392
+ validator.add_argument(
393
+ "test_splits",
394
+ type=list,
395
+ help="Splits to use for testing. If not provided, will skip the testing.",
396
+ )
397
+ # add arguments for accumulating gradient for iterations
398
+ validator.add_argument(
399
+ "accum_grad_iters",
400
+ type=int,
401
+ help="Number of iterations to accumulate gradient for.",
402
+ )
403
+
404
+ # ====== distributed training ======
405
+ validator.add_argument(
406
+ "device",
407
+ type=str,
408
+ choices=["cpu", "cuda"],
409
+ help="Device to use. Support 'cuda' or 'cpu' as for now.",
410
+ )
411
+ validator.add_argument(
412
+ "world_size",
413
+ type=int,
414
+ help="Number of processes participating in the job.",
415
+ )
416
+ validator.add_argument("dist_url", type=str)
417
+ validator.add_argument("distributed", type=bool)
418
+ # add arguments to opt using distributed sampler during evaluation or not
419
+ validator.add_argument(
420
+ "use_dist_eval_sampler",
421
+ type=bool,
422
+ help="Whether to use distributed sampler during evaluation or not.",
423
+ )
424
+
425
+ # ====== task specific ======
426
+ # generation task specific arguments
427
+ # add arguments for maximal length of text output
428
+ validator.add_argument(
429
+ "max_len",
430
+ type=int,
431
+ help="Maximal length of text output.",
432
+ )
433
+ # add arguments for minimal length of text output
434
+ validator.add_argument(
435
+ "min_len",
436
+ type=int,
437
+ help="Minimal length of text output.",
438
+ )
439
+ # add arguments number of beams
440
+ validator.add_argument(
441
+ "num_beams",
442
+ type=int,
443
+ help="Number of beams used for beam search.",
444
+ )
445
+
446
+ # vqa task specific arguments
447
+ # add arguments for number of answer candidates
448
+ validator.add_argument(
449
+ "num_ans_candidates",
450
+ type=int,
451
+ help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
452
+ )
453
+ # add arguments for inference method
454
+ validator.add_argument(
455
+ "inference_method",
456
+ type=str,
457
+ choices=["genearte", "rank"],
458
+ help="""Inference method to use for question answering. If rank, requires a answer list.""",
459
+ )
460
+
461
+ # ====== model specific ======
462
+ validator.add_argument(
463
+ "k_test",
464
+ type=int,
465
+ help="Number of top k most similar samples from ITC/VTC selection to be tested.",
466
+ )
467
+
468
+ return validator
lavis/common/dist_utils.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import functools
10
+ import os
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ import timm.models.hub as timm_hub
15
+
16
+
17
+ def setup_for_distributed(is_master):
18
+ """
19
+ This function disables printing when not in master process
20
+ """
21
+ import builtins as __builtin__
22
+
23
+ builtin_print = __builtin__.print
24
+
25
+ def print(*args, **kwargs):
26
+ force = kwargs.pop("force", False)
27
+ if is_master or force:
28
+ builtin_print(*args, **kwargs)
29
+
30
+ __builtin__.print = print
31
+
32
+
33
+ def is_dist_avail_and_initialized():
34
+ if not dist.is_available():
35
+ return False
36
+ if not dist.is_initialized():
37
+ return False
38
+ return True
39
+
40
+
41
+ def get_world_size():
42
+ if not is_dist_avail_and_initialized():
43
+ return 1
44
+ return dist.get_world_size()
45
+
46
+
47
+ def get_rank():
48
+ if not is_dist_avail_and_initialized():
49
+ return 0
50
+ return dist.get_rank()
51
+
52
+
53
+ def is_main_process():
54
+ return get_rank() == 0
55
+
56
+
57
+ def init_distributed_mode(args):
58
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
59
+ args.rank = int(os.environ["RANK"])
60
+ args.world_size = int(os.environ["WORLD_SIZE"])
61
+ args.gpu = int(os.environ["LOCAL_RANK"])
62
+ elif "SLURM_PROCID" in os.environ:
63
+ args.rank = int(os.environ["SLURM_PROCID"])
64
+ args.gpu = args.rank % torch.cuda.device_count()
65
+ else:
66
+ print("Not using distributed mode")
67
+ args.distributed = False
68
+ return
69
+
70
+ args.distributed = True
71
+
72
+ torch.cuda.set_device(args.gpu)
73
+ args.dist_backend = "nccl"
74
+ print(
75
+ "| distributed init (rank {}, world {}): {}".format(
76
+ args.rank, args.world_size, args.dist_url
77
+ ),
78
+ flush=True,
79
+ )
80
+ torch.distributed.init_process_group(
81
+ backend=args.dist_backend,
82
+ init_method=args.dist_url,
83
+ world_size=args.world_size,
84
+ rank=args.rank,
85
+ timeout=datetime.timedelta(
86
+ days=365
87
+ ), # allow auto-downloading and de-compressing
88
+ )
89
+ torch.distributed.barrier()
90
+ setup_for_distributed(args.rank == 0)
91
+
92
+
93
+ def get_dist_info():
94
+ if torch.__version__ < "1.0":
95
+ initialized = dist._initialized
96
+ else:
97
+ initialized = dist.is_initialized()
98
+ if initialized:
99
+ rank = dist.get_rank()
100
+ world_size = dist.get_world_size()
101
+ else: # non-distributed training
102
+ rank = 0
103
+ world_size = 1
104
+ return rank, world_size
105
+
106
+
107
+ def main_process(func):
108
+ @functools.wraps(func)
109
+ def wrapper(*args, **kwargs):
110
+ rank, _ = get_dist_info()
111
+ if rank == 0:
112
+ return func(*args, **kwargs)
113
+
114
+ return wrapper
115
+
116
+
117
+ def download_cached_file(url, check_hash=True, progress=False):
118
+ """
119
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
120
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
121
+ """
122
+
123
+ def get_cached_file_path():
124
+ # a hack to sync the file path across processes
125
+ parts = torch.hub.urlparse(url)
126
+ filename = os.path.basename(parts.path)
127
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
128
+
129
+ return cached_file
130
+
131
+ if is_main_process():
132
+ timm_hub.download_cached_file(url, check_hash, progress)
133
+
134
+ if is_dist_avail_and_initialized():
135
+ dist.barrier()
136
+
137
+ return get_cached_file_path()
lavis/common/gradcam.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from matplotlib import pyplot as plt
3
+ from scipy.ndimage import filters
4
+ from skimage import transform as skimage_transform
5
+
6
+
7
+ def getAttMap(img, attMap, blur=True, overlap=True):
8
+ attMap -= attMap.min()
9
+ if attMap.max() > 0:
10
+ attMap /= attMap.max()
11
+ attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12
+ if blur:
13
+ attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14
+ attMap -= attMap.min()
15
+ attMap /= attMap.max()
16
+ cmap = plt.get_cmap("jet")
17
+ attMapV = cmap(attMap)
18
+ attMapV = np.delete(attMapV, 3, 2)
19
+ if overlap:
20
+ attMap = (
21
+ 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22
+ + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23
+ )
24
+ return attMap
lavis/common/logger.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import logging
10
+ import time
11
+ from collections import defaultdict, deque
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+
16
+ from lavis.common import dist_utils
17
+
18
+
19
+ class SmoothedValue(object):
20
+ """Track a series of values and provide access to smoothed values over a
21
+ window or the global series average.
22
+ """
23
+
24
+ def __init__(self, window_size=20, fmt=None):
25
+ if fmt is None:
26
+ fmt = "{median:.4f} ({global_avg:.4f})"
27
+ self.deque = deque(maxlen=window_size)
28
+ self.total = 0.0
29
+ self.count = 0
30
+ self.fmt = fmt
31
+
32
+ def update(self, value, n=1):
33
+ self.deque.append(value)
34
+ self.count += n
35
+ self.total += value * n
36
+
37
+ def synchronize_between_processes(self):
38
+ """
39
+ Warning: does not synchronize the deque!
40
+ """
41
+ if not dist_utils.is_dist_avail_and_initialized():
42
+ return
43
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
44
+ dist.barrier()
45
+ dist.all_reduce(t)
46
+ t = t.tolist()
47
+ self.count = int(t[0])
48
+ self.total = t[1]
49
+
50
+ @property
51
+ def median(self):
52
+ d = torch.tensor(list(self.deque))
53
+ return d.median().item()
54
+
55
+ @property
56
+ def avg(self):
57
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
58
+ return d.mean().item()
59
+
60
+ @property
61
+ def global_avg(self):
62
+ return self.total / self.count
63
+
64
+ @property
65
+ def max(self):
66
+ return max(self.deque)
67
+
68
+ @property
69
+ def value(self):
70
+ return self.deque[-1]
71
+
72
+ def __str__(self):
73
+ return self.fmt.format(
74
+ median=self.median,
75
+ avg=self.avg,
76
+ global_avg=self.global_avg,
77
+ max=self.max,
78
+ value=self.value,
79
+ )
80
+
81
+
82
+ class MetricLogger(object):
83
+ def __init__(self, delimiter="\t"):
84
+ self.meters = defaultdict(SmoothedValue)
85
+ self.delimiter = delimiter
86
+
87
+ def update(self, **kwargs):
88
+ for k, v in kwargs.items():
89
+ if isinstance(v, torch.Tensor):
90
+ v = v.item()
91
+ assert isinstance(v, (float, int))
92
+ self.meters[k].update(v)
93
+
94
+ def __getattr__(self, attr):
95
+ if attr in self.meters:
96
+ return self.meters[attr]
97
+ if attr in self.__dict__:
98
+ return self.__dict__[attr]
99
+ raise AttributeError(
100
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
101
+ )
102
+
103
+ def __str__(self):
104
+ loss_str = []
105
+ for name, meter in self.meters.items():
106
+ loss_str.append("{}: {}".format(name, str(meter)))
107
+ return self.delimiter.join(loss_str)
108
+
109
+ def global_avg(self):
110
+ loss_str = []
111
+ for name, meter in self.meters.items():
112
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
113
+ return self.delimiter.join(loss_str)
114
+
115
+ def synchronize_between_processes(self):
116
+ for meter in self.meters.values():
117
+ meter.synchronize_between_processes()
118
+
119
+ def add_meter(self, name, meter):
120
+ self.meters[name] = meter
121
+
122
+ def log_every(self, iterable, print_freq, header=None):
123
+ i = 0
124
+ if not header:
125
+ header = ""
126
+ start_time = time.time()
127
+ end = time.time()
128
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
129
+ data_time = SmoothedValue(fmt="{avg:.4f}")
130
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
131
+ log_msg = [
132
+ header,
133
+ "[{0" + space_fmt + "}/{1}]",
134
+ "eta: {eta}",
135
+ "{meters}",
136
+ "time: {time}",
137
+ "data: {data}",
138
+ ]
139
+ if torch.cuda.is_available():
140
+ log_msg.append("max mem: {memory:.0f}")
141
+ log_msg = self.delimiter.join(log_msg)
142
+ MB = 1024.0 * 1024.0
143
+ for obj in iterable:
144
+ data_time.update(time.time() - end)
145
+ yield obj
146
+ iter_time.update(time.time() - end)
147
+ if i % print_freq == 0 or i == len(iterable) - 1:
148
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
149
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
150
+ if torch.cuda.is_available():
151
+ print(
152
+ log_msg.format(
153
+ i,
154
+ len(iterable),
155
+ eta=eta_string,
156
+ meters=str(self),
157
+ time=str(iter_time),
158
+ data=str(data_time),
159
+ memory=torch.cuda.max_memory_allocated() / MB,
160
+ )
161
+ )
162
+ else:
163
+ print(
164
+ log_msg.format(
165
+ i,
166
+ len(iterable),
167
+ eta=eta_string,
168
+ meters=str(self),
169
+ time=str(iter_time),
170
+ data=str(data_time),
171
+ )
172
+ )
173
+ i += 1
174
+ end = time.time()
175
+ total_time = time.time() - start_time
176
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
177
+ print(
178
+ "{} Total time: {} ({:.4f} s / it)".format(
179
+ header, total_time_str, total_time / len(iterable)
180
+ )
181
+ )
182
+
183
+
184
+ class AttrDict(dict):
185
+ def __init__(self, *args, **kwargs):
186
+ super(AttrDict, self).__init__(*args, **kwargs)
187
+ self.__dict__ = self
188
+
189
+
190
+ def setup_logger():
191
+ logging.basicConfig(
192
+ level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
193
+ format="%(asctime)s [%(levelname)s] %(message)s",
194
+ handlers=[logging.StreamHandler()],
195
+ )
lavis/common/optims.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import math
9
+
10
+ from lavis.common.registry import registry
11
+
12
+
13
+ @registry.register_lr_scheduler("linear_warmup_step_lr")
14
+ class LinearWarmupStepLRScheduler:
15
+ def __init__(
16
+ self,
17
+ optimizer,
18
+ max_epoch,
19
+ min_lr,
20
+ init_lr,
21
+ decay_rate=1,
22
+ warmup_start_lr=-1,
23
+ warmup_steps=0,
24
+ **kwargs
25
+ ):
26
+ self.optimizer = optimizer
27
+
28
+ self.max_epoch = max_epoch
29
+ self.min_lr = min_lr
30
+
31
+ self.decay_rate = decay_rate
32
+
33
+ self.init_lr = init_lr
34
+ self.warmup_steps = warmup_steps
35
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
36
+
37
+ def step(self, cur_epoch, cur_step):
38
+ if cur_epoch == 0:
39
+ warmup_lr_schedule(
40
+ step=cur_step,
41
+ optimizer=self.optimizer,
42
+ max_step=self.warmup_steps,
43
+ init_lr=self.warmup_start_lr,
44
+ max_lr=self.init_lr,
45
+ )
46
+ else:
47
+ step_lr_schedule(
48
+ epoch=cur_epoch,
49
+ optimizer=self.optimizer,
50
+ init_lr=self.init_lr,
51
+ min_lr=self.min_lr,
52
+ decay_rate=self.decay_rate,
53
+ )
54
+
55
+
56
+ @registry.register_lr_scheduler("linear_warmup_cosine_lr")
57
+ class LinearWarmupCosineLRScheduler:
58
+ def __init__(
59
+ self,
60
+ optimizer,
61
+ max_epoch,
62
+ min_lr,
63
+ init_lr,
64
+ warmup_steps=0,
65
+ warmup_start_lr=-1,
66
+ **kwargs
67
+ ):
68
+ self.optimizer = optimizer
69
+
70
+ self.max_epoch = max_epoch
71
+ self.min_lr = min_lr
72
+
73
+ self.init_lr = init_lr
74
+ self.warmup_steps = warmup_steps
75
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
76
+
77
+ def step(self, cur_epoch, cur_step):
78
+ # assuming the warmup iters less than one epoch
79
+ if cur_epoch == 0:
80
+ warmup_lr_schedule(
81
+ step=cur_step,
82
+ optimizer=self.optimizer,
83
+ max_step=self.warmup_steps,
84
+ init_lr=self.warmup_start_lr,
85
+ max_lr=self.init_lr,
86
+ )
87
+ else:
88
+ cosine_lr_schedule(
89
+ epoch=cur_epoch,
90
+ optimizer=self.optimizer,
91
+ max_epoch=self.max_epoch,
92
+ init_lr=self.init_lr,
93
+ min_lr=self.min_lr,
94
+ )
95
+
96
+
97
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
98
+ """Decay the learning rate"""
99
+ lr = (init_lr - min_lr) * 0.5 * (
100
+ 1.0 + math.cos(math.pi * epoch / max_epoch)
101
+ ) + min_lr
102
+ for param_group in optimizer.param_groups:
103
+ param_group["lr"] = lr
104
+
105
+
106
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
107
+ """Warmup the learning rate"""
108
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
109
+ for param_group in optimizer.param_groups:
110
+ param_group["lr"] = lr
111
+
112
+
113
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
114
+ """Decay the learning rate"""
115
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
116
+ for param_group in optimizer.param_groups:
117
+ param_group["lr"] = lr
lavis/common/registry.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+
9
+ class Registry:
10
+ mapping = {
11
+ "builder_name_mapping": {},
12
+ "task_name_mapping": {},
13
+ "processor_name_mapping": {},
14
+ "model_name_mapping": {},
15
+ "lr_scheduler_name_mapping": {},
16
+ "runner_name_mapping": {},
17
+ "state": {},
18
+ "paths": {},
19
+ }
20
+
21
+ @classmethod
22
+ def register_builder(cls, name):
23
+ r"""Register a dataset builder to registry with key 'name'
24
+
25
+ Args:
26
+ name: Key with which the builder will be registered.
27
+
28
+ Usage:
29
+
30
+ from lavis.common.registry import registry
31
+ from lavis.datasets.base_dataset_builder import BaseDatasetBuilder
32
+ """
33
+
34
+ def wrap(builder_cls):
35
+ from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder
36
+
37
+ assert issubclass(
38
+ builder_cls, BaseDatasetBuilder
39
+ ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
40
+ builder_cls
41
+ )
42
+ if name in cls.mapping["builder_name_mapping"]:
43
+ raise KeyError(
44
+ "Name '{}' already registered for {}.".format(
45
+ name, cls.mapping["builder_name_mapping"][name]
46
+ )
47
+ )
48
+ cls.mapping["builder_name_mapping"][name] = builder_cls
49
+ return builder_cls
50
+
51
+ return wrap
52
+
53
+ @classmethod
54
+ def register_task(cls, name):
55
+ r"""Register a task to registry with key 'name'
56
+
57
+ Args:
58
+ name: Key with which the task will be registered.
59
+
60
+ Usage:
61
+
62
+ from lavis.common.registry import registry
63
+ """
64
+
65
+ def wrap(task_cls):
66
+ from lavis.tasks.base_task import BaseTask
67
+
68
+ assert issubclass(
69
+ task_cls, BaseTask
70
+ ), "All tasks must inherit BaseTask class"
71
+ if name in cls.mapping["task_name_mapping"]:
72
+ raise KeyError(
73
+ "Name '{}' already registered for {}.".format(
74
+ name, cls.mapping["task_name_mapping"][name]
75
+ )
76
+ )
77
+ cls.mapping["task_name_mapping"][name] = task_cls
78
+ return task_cls
79
+
80
+ return wrap
81
+
82
+ @classmethod
83
+ def register_model(cls, name):
84
+ r"""Register a task to registry with key 'name'
85
+
86
+ Args:
87
+ name: Key with which the task will be registered.
88
+
89
+ Usage:
90
+
91
+ from lavis.common.registry import registry
92
+ """
93
+
94
+ def wrap(model_cls):
95
+ from lavis.models import BaseModel
96
+
97
+ assert issubclass(
98
+ model_cls, BaseModel
99
+ ), "All models must inherit BaseModel class"
100
+ if name in cls.mapping["model_name_mapping"]:
101
+ raise KeyError(
102
+ "Name '{}' already registered for {}.".format(
103
+ name, cls.mapping["model_name_mapping"][name]
104
+ )
105
+ )
106
+ cls.mapping["model_name_mapping"][name] = model_cls
107
+ return model_cls
108
+
109
+ return wrap
110
+
111
+ @classmethod
112
+ def register_processor(cls, name):
113
+ r"""Register a processor to registry with key 'name'
114
+
115
+ Args:
116
+ name: Key with which the task will be registered.
117
+
118
+ Usage:
119
+
120
+ from lavis.common.registry import registry
121
+ """
122
+
123
+ def wrap(processor_cls):
124
+ from lavis.processors import BaseProcessor
125
+
126
+ assert issubclass(
127
+ processor_cls, BaseProcessor
128
+ ), "All processors must inherit BaseProcessor class"
129
+ if name in cls.mapping["processor_name_mapping"]:
130
+ raise KeyError(
131
+ "Name '{}' already registered for {}.".format(
132
+ name, cls.mapping["processor_name_mapping"][name]
133
+ )
134
+ )
135
+ cls.mapping["processor_name_mapping"][name] = processor_cls
136
+ return processor_cls
137
+
138
+ return wrap
139
+
140
+ @classmethod
141
+ def register_lr_scheduler(cls, name):
142
+ r"""Register a model to registry with key 'name'
143
+
144
+ Args:
145
+ name: Key with which the task will be registered.
146
+
147
+ Usage:
148
+
149
+ from lavis.common.registry import registry
150
+ """
151
+
152
+ def wrap(lr_sched_cls):
153
+ if name in cls.mapping["lr_scheduler_name_mapping"]:
154
+ raise KeyError(
155
+ "Name '{}' already registered for {}.".format(
156
+ name, cls.mapping["lr_scheduler_name_mapping"][name]
157
+ )
158
+ )
159
+ cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
160
+ return lr_sched_cls
161
+
162
+ return wrap
163
+
164
+ @classmethod
165
+ def register_runner(cls, name):
166
+ r"""Register a model to registry with key 'name'
167
+
168
+ Args:
169
+ name: Key with which the task will be registered.
170
+
171
+ Usage:
172
+
173
+ from lavis.common.registry import registry
174
+ """
175
+
176
+ def wrap(runner_cls):
177
+ if name in cls.mapping["runner_name_mapping"]:
178
+ raise KeyError(
179
+ "Name '{}' already registered for {}.".format(
180
+ name, cls.mapping["runner_name_mapping"][name]
181
+ )
182
+ )
183
+ cls.mapping["runner_name_mapping"][name] = runner_cls
184
+ return runner_cls
185
+
186
+ return wrap
187
+
188
+ @classmethod
189
+ def register_path(cls, name, path):
190
+ r"""Register a path to registry with key 'name'
191
+
192
+ Args:
193
+ name: Key with which the path will be registered.
194
+
195
+ Usage:
196
+
197
+ from lavis.common.registry import registry
198
+ """
199
+ assert isinstance(path, str), "All path must be str."
200
+ if name in cls.mapping["paths"]:
201
+ raise KeyError("Name '{}' already registered.".format(name))
202
+ cls.mapping["paths"][name] = path
203
+
204
+ @classmethod
205
+ def register(cls, name, obj):
206
+ r"""Register an item to registry with key 'name'
207
+
208
+ Args:
209
+ name: Key with which the item will be registered.
210
+
211
+ Usage::
212
+
213
+ from lavis.common.registry import registry
214
+
215
+ registry.register("config", {})
216
+ """
217
+ path = name.split(".")
218
+ current = cls.mapping["state"]
219
+
220
+ for part in path[:-1]:
221
+ if part not in current:
222
+ current[part] = {}
223
+ current = current[part]
224
+
225
+ current[path[-1]] = obj
226
+
227
+ # @classmethod
228
+ # def get_trainer_class(cls, name):
229
+ # return cls.mapping["trainer_name_mapping"].get(name, None)
230
+
231
+ @classmethod
232
+ def get_builder_class(cls, name):
233
+ return cls.mapping["builder_name_mapping"].get(name, None)
234
+
235
+ @classmethod
236
+ def get_model_class(cls, name):
237
+ return cls.mapping["model_name_mapping"].get(name, None)
238
+
239
+ @classmethod
240
+ def get_task_class(cls, name):
241
+ return cls.mapping["task_name_mapping"].get(name, None)
242
+
243
+ @classmethod
244
+ def get_processor_class(cls, name):
245
+ return cls.mapping["processor_name_mapping"].get(name, None)
246
+
247
+ @classmethod
248
+ def get_lr_scheduler_class(cls, name):
249
+ return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
250
+
251
+ @classmethod
252
+ def get_runner_class(cls, name):
253
+ return cls.mapping["runner_name_mapping"].get(name, None)
254
+
255
+ @classmethod
256
+ def list_runners(cls):
257
+ return sorted(cls.mapping["runner_name_mapping"].keys())
258
+
259
+ @classmethod
260
+ def list_models(cls):
261
+ return sorted(cls.mapping["model_name_mapping"].keys())
262
+
263
+ @classmethod
264
+ def list_tasks(cls):
265
+ return sorted(cls.mapping["task_name_mapping"].keys())
266
+
267
+ @classmethod
268
+ def list_processors(cls):
269
+ return sorted(cls.mapping["processor_name_mapping"].keys())
270
+
271
+ @classmethod
272
+ def list_lr_schedulers(cls):
273
+ return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
274
+
275
+ @classmethod
276
+ def list_datasets(cls):
277
+ return sorted(cls.mapping["builder_name_mapping"].keys())
278
+
279
+ @classmethod
280
+ def get_path(cls, name):
281
+ return cls.mapping["paths"].get(name, None)
282
+
283
+ @classmethod
284
+ def get(cls, name, default=None, no_warning=False):
285
+ r"""Get an item from registry with key 'name'
286
+
287
+ Args:
288
+ name (string): Key whose value needs to be retrieved.
289
+ default: If passed and key is not in registry, default value will
290
+ be returned with a warning. Default: None
291
+ no_warning (bool): If passed as True, warning when key doesn't exist
292
+ will not be generated. Useful for MMF's
293
+ internal operations. Default: False
294
+ """
295
+ original_name = name
296
+ name = name.split(".")
297
+ value = cls.mapping["state"]
298
+ for subname in name:
299
+ value = value.get(subname, default)
300
+ if value is default:
301
+ break
302
+
303
+ if (
304
+ "writer" in cls.mapping["state"]
305
+ and value == default
306
+ and no_warning is False
307
+ ):
308
+ cls.mapping["state"]["writer"].warning(
309
+ "Key {} is not present in registry, returning default value "
310
+ "of {}".format(original_name, default)
311
+ )
312
+ return value
313
+
314
+ @classmethod
315
+ def unregister(cls, name):
316
+ r"""Remove an item from registry with key 'name'
317
+
318
+ Args:
319
+ name: Key which needs to be removed.
320
+ Usage::
321
+
322
+ from mmf.common.registry import registry
323
+
324
+ config = registry.unregister("config")
325
+ """
326
+ return cls.mapping["state"].pop(name, None)
327
+
328
+
329
+ registry = Registry()
lavis/common/utils.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import io
9
+ import json
10
+ import logging
11
+ import os
12
+ import pickle
13
+ import re
14
+ import shutil
15
+ import urllib
16
+ import urllib.error
17
+ import urllib.request
18
+ from typing import Optional
19
+ from urllib.parse import urlparse
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ import yaml
24
+ from iopath.common.download import download
25
+ from iopath.common.file_io import file_lock, g_pathmgr
26
+ from lavis.common.registry import registry
27
+ from torch.utils.model_zoo import tqdm
28
+ from torchvision.datasets.utils import (
29
+ check_integrity,
30
+ download_file_from_google_drive,
31
+ extract_archive,
32
+ )
33
+
34
+
35
+ def now():
36
+ from datetime import datetime
37
+
38
+ return datetime.now().strftime("%Y%m%d%H%M")[:-1]
39
+
40
+
41
+ def is_url(url_or_filename):
42
+ parsed = urlparse(url_or_filename)
43
+ return parsed.scheme in ("http", "https")
44
+
45
+
46
+ def get_cache_path(rel_path):
47
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
48
+
49
+
50
+ def get_abs_path(rel_path):
51
+ return os.path.join(registry.get_path("library_root"), rel_path)
52
+
53
+
54
+ def load_json(filename):
55
+ with open(filename, "r") as f:
56
+ return json.load(f)
57
+
58
+
59
+ # The following are adapted from torchvision and vissl
60
+ # torchvision: https://github.com/pytorch/vision
61
+ # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
62
+
63
+
64
+ def makedir(dir_path):
65
+ """
66
+ Create the directory if it does not exist.
67
+ """
68
+ is_success = False
69
+ try:
70
+ if not g_pathmgr.exists(dir_path):
71
+ g_pathmgr.mkdirs(dir_path)
72
+ is_success = True
73
+ except BaseException:
74
+ print(f"Error creating directory: {dir_path}")
75
+ return is_success
76
+
77
+
78
+ def get_redirected_url(url: str):
79
+ """
80
+ Given a URL, returns the URL it redirects to or the
81
+ original URL in case of no indirection
82
+ """
83
+ import requests
84
+
85
+ with requests.Session() as session:
86
+ with session.get(url, stream=True, allow_redirects=True) as response:
87
+ if response.history:
88
+ return response.url
89
+ else:
90
+ return url
91
+
92
+
93
+ def to_google_drive_download_url(view_url: str) -> str:
94
+ """
95
+ Utility function to transform a view URL of google drive
96
+ to a download URL for google drive
97
+ Example input:
98
+ https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
99
+ Example output:
100
+ https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
101
+ """
102
+ splits = view_url.split("/")
103
+ assert splits[-1] == "view"
104
+ file_id = splits[-2]
105
+ return f"https://drive.google.com/uc?export=download&id={file_id}"
106
+
107
+
108
+ def download_google_drive_url(url: str, output_path: str, output_file_name: str):
109
+ """
110
+ Download a file from google drive
111
+ Downloading an URL from google drive requires confirmation when
112
+ the file of the size is too big (google drive notifies that
113
+ anti-viral checks cannot be performed on such files)
114
+ """
115
+ import requests
116
+
117
+ with requests.Session() as session:
118
+
119
+ # First get the confirmation token and append it to the URL
120
+ with session.get(url, stream=True, allow_redirects=True) as response:
121
+ for k, v in response.cookies.items():
122
+ if k.startswith("download_warning"):
123
+ url = url + "&confirm=" + v
124
+
125
+ # Then download the content of the file
126
+ with session.get(url, stream=True, verify=True) as response:
127
+ makedir(output_path)
128
+ path = os.path.join(output_path, output_file_name)
129
+ total_size = int(response.headers.get("Content-length", 0))
130
+ with open(path, "wb") as file:
131
+ from tqdm import tqdm
132
+
133
+ with tqdm(total=total_size) as progress_bar:
134
+ for block in response.iter_content(
135
+ chunk_size=io.DEFAULT_BUFFER_SIZE
136
+ ):
137
+ file.write(block)
138
+ progress_bar.update(len(block))
139
+
140
+
141
+ def _get_google_drive_file_id(url: str) -> Optional[str]:
142
+ parts = urlparse(url)
143
+
144
+ if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
145
+ return None
146
+
147
+ match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
148
+ if match is None:
149
+ return None
150
+
151
+ return match.group("id")
152
+
153
+
154
+ def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
155
+ with open(filename, "wb") as fh:
156
+ with urllib.request.urlopen(
157
+ urllib.request.Request(url, headers={"User-Agent": "vissl"})
158
+ ) as response:
159
+ with tqdm(total=response.length) as pbar:
160
+ for chunk in iter(lambda: response.read(chunk_size), ""):
161
+ if not chunk:
162
+ break
163
+ pbar.update(chunk_size)
164
+ fh.write(chunk)
165
+
166
+
167
+ def download_url(
168
+ url: str,
169
+ root: str,
170
+ filename: Optional[str] = None,
171
+ md5: Optional[str] = None,
172
+ ) -> None:
173
+ """Download a file from a url and place it in root.
174
+ Args:
175
+ url (str): URL to download file from
176
+ root (str): Directory to place downloaded file in
177
+ filename (str, optional): Name to save the file under.
178
+ If None, use the basename of the URL.
179
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
180
+ """
181
+ root = os.path.expanduser(root)
182
+ if not filename:
183
+ filename = os.path.basename(url)
184
+ fpath = os.path.join(root, filename)
185
+
186
+ makedir(root)
187
+
188
+ # check if file is already present locally
189
+ if check_integrity(fpath, md5):
190
+ print("Using downloaded and verified file: " + fpath)
191
+ return
192
+
193
+ # expand redirect chain if needed
194
+ url = get_redirected_url(url)
195
+
196
+ # check if file is located on Google Drive
197
+ file_id = _get_google_drive_file_id(url)
198
+ if file_id is not None:
199
+ return download_file_from_google_drive(file_id, root, filename, md5)
200
+
201
+ # download the file
202
+ try:
203
+ print("Downloading " + url + " to " + fpath)
204
+ _urlretrieve(url, fpath)
205
+ except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
206
+ if url[:5] == "https":
207
+ url = url.replace("https:", "http:")
208
+ print(
209
+ "Failed download. Trying https -> http instead."
210
+ " Downloading " + url + " to " + fpath
211
+ )
212
+ _urlretrieve(url, fpath)
213
+ else:
214
+ raise e
215
+
216
+ # check integrity of downloaded file
217
+ if not check_integrity(fpath, md5):
218
+ raise RuntimeError("File not found or corrupted.")
219
+
220
+
221
+ def download_and_extract_archive(
222
+ url: str,
223
+ download_root: str,
224
+ extract_root: Optional[str] = None,
225
+ filename: Optional[str] = None,
226
+ md5: Optional[str] = None,
227
+ remove_finished: bool = False,
228
+ ) -> None:
229
+ download_root = os.path.expanduser(download_root)
230
+ if extract_root is None:
231
+ extract_root = download_root
232
+ if not filename:
233
+ filename = os.path.basename(url)
234
+
235
+ download_url(url, download_root, filename, md5)
236
+
237
+ archive = os.path.join(download_root, filename)
238
+ print("Extracting {} to {}".format(archive, extract_root))
239
+ extract_archive(archive, extract_root, remove_finished)
240
+
241
+
242
+ def cache_url(url: str, cache_dir: str) -> str:
243
+ """
244
+ This implementation downloads the remote resource and caches it locally.
245
+ The resource will only be downloaded if not previously requested.
246
+ """
247
+ parsed_url = urlparse(url)
248
+ dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
249
+ makedir(dirname)
250
+ filename = url.split("/")[-1]
251
+ cached = os.path.join(dirname, filename)
252
+ with file_lock(cached):
253
+ if not os.path.isfile(cached):
254
+ logging.info(f"Downloading {url} to {cached} ...")
255
+ cached = download(url, dirname, filename=filename)
256
+ logging.info(f"URL {url} cached in {cached}")
257
+ return cached
258
+
259
+
260
+ # TODO (prigoyal): convert this into RAII-style API
261
+ def create_file_symlink(file1, file2):
262
+ """
263
+ Simply create the symlinks for a given file1 to file2.
264
+ Useful during model checkpointing to symlinks to the
265
+ latest successful checkpoint.
266
+ """
267
+ try:
268
+ if g_pathmgr.exists(file2):
269
+ g_pathmgr.rm(file2)
270
+ g_pathmgr.symlink(file1, file2)
271
+ except Exception as e:
272
+ logging.info(f"Could NOT create symlink. Error: {e}")
273
+
274
+
275
+ def save_file(data, filename, append_to_json=True, verbose=True):
276
+ """
277
+ Common i/o utility to handle saving data to various file formats.
278
+ Supported:
279
+ .pkl, .pickle, .npy, .json
280
+ Specifically for .json, users have the option to either append (default)
281
+ or rewrite by passing in Boolean value to append_to_json.
282
+ """
283
+ if verbose:
284
+ logging.info(f"Saving data to file: {filename}")
285
+ file_ext = os.path.splitext(filename)[1]
286
+ if file_ext in [".pkl", ".pickle"]:
287
+ with g_pathmgr.open(filename, "wb") as fopen:
288
+ pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
289
+ elif file_ext == ".npy":
290
+ with g_pathmgr.open(filename, "wb") as fopen:
291
+ np.save(fopen, data)
292
+ elif file_ext == ".json":
293
+ if append_to_json:
294
+ with g_pathmgr.open(filename, "a") as fopen:
295
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
296
+ fopen.flush()
297
+ else:
298
+ with g_pathmgr.open(filename, "w") as fopen:
299
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
300
+ fopen.flush()
301
+ elif file_ext == ".yaml":
302
+ with g_pathmgr.open(filename, "w") as fopen:
303
+ dump = yaml.dump(data)
304
+ fopen.write(dump)
305
+ fopen.flush()
306
+ else:
307
+ raise Exception(f"Saving {file_ext} is not supported yet")
308
+
309
+ if verbose:
310
+ logging.info(f"Saved data to file: {filename}")
311
+
312
+
313
+ def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
314
+ """
315
+ Common i/o utility to handle loading data from various file formats.
316
+ Supported:
317
+ .pkl, .pickle, .npy, .json
318
+ For the npy files, we support reading the files in mmap_mode.
319
+ If the mmap_mode of reading is not successful, we load data without the
320
+ mmap_mode.
321
+ """
322
+ if verbose:
323
+ logging.info(f"Loading data from file: {filename}")
324
+
325
+ file_ext = os.path.splitext(filename)[1]
326
+ if file_ext == ".txt":
327
+ with g_pathmgr.open(filename, "r") as fopen:
328
+ data = fopen.readlines()
329
+ elif file_ext in [".pkl", ".pickle"]:
330
+ with g_pathmgr.open(filename, "rb") as fopen:
331
+ data = pickle.load(fopen, encoding="latin1")
332
+ elif file_ext == ".npy":
333
+ if mmap_mode:
334
+ try:
335
+ with g_pathmgr.open(filename, "rb") as fopen:
336
+ data = np.load(
337
+ fopen,
338
+ allow_pickle=allow_pickle,
339
+ encoding="latin1",
340
+ mmap_mode=mmap_mode,
341
+ )
342
+ except ValueError as e:
343
+ logging.info(
344
+ f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
345
+ )
346
+ data = np.load(
347
+ filename,
348
+ allow_pickle=allow_pickle,
349
+ encoding="latin1",
350
+ mmap_mode=mmap_mode,
351
+ )
352
+ logging.info("Successfully loaded without g_pathmgr")
353
+ except Exception:
354
+ logging.info("Could not mmap without g_pathmgr. Trying without mmap")
355
+ with g_pathmgr.open(filename, "rb") as fopen:
356
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
357
+ else:
358
+ with g_pathmgr.open(filename, "rb") as fopen:
359
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
360
+ elif file_ext == ".json":
361
+ with g_pathmgr.open(filename, "r") as fopen:
362
+ data = json.load(fopen)
363
+ elif file_ext == ".yaml":
364
+ with g_pathmgr.open(filename, "r") as fopen:
365
+ data = yaml.load(fopen, Loader=yaml.FullLoader)
366
+ elif file_ext == ".csv":
367
+ with g_pathmgr.open(filename, "r") as fopen:
368
+ data = pd.read_csv(fopen)
369
+ else:
370
+ raise Exception(f"Reading from {file_ext} is not supported yet")
371
+ return data
372
+
373
+
374
+ def abspath(resource_path: str):
375
+ """
376
+ Make a path absolute, but take into account prefixes like
377
+ "http://" or "manifold://"
378
+ """
379
+ regex = re.compile(r"^\w+://")
380
+ if regex.match(resource_path) is None:
381
+ return os.path.abspath(resource_path)
382
+ else:
383
+ return resource_path
384
+
385
+
386
+ def makedir(dir_path):
387
+ """
388
+ Create the directory if it does not exist.
389
+ """
390
+ is_success = False
391
+ try:
392
+ if not g_pathmgr.exists(dir_path):
393
+ g_pathmgr.mkdirs(dir_path)
394
+ is_success = True
395
+ except BaseException:
396
+ logging.info(f"Error creating directory: {dir_path}")
397
+ return is_success
398
+
399
+
400
+ def is_url(input_url):
401
+ """
402
+ Check if an input string is a url. look for http(s):// and ignoring the case
403
+ """
404
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
405
+ return is_url
406
+
407
+
408
+ def cleanup_dir(dir):
409
+ """
410
+ Utility for deleting a directory. Useful for cleaning the storage space
411
+ that contains various training artifacts like checkpoints, data etc.
412
+ """
413
+ if os.path.exists(dir):
414
+ logging.info(f"Deleting directory: {dir}")
415
+ shutil.rmtree(dir)
416
+ logging.info(f"Deleted contents of directory: {dir}")
417
+
418
+
419
+ def get_file_size(filename):
420
+ """
421
+ Given a file, get the size of file in MB
422
+ """
423
+ size_in_mb = os.path.getsize(filename) / float(1024**2)
424
+ return size_in_mb
lavis/common/vqa_tools/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ __author__ = "aagrawal"
lavis/common/vqa_tools/vqa.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ __author__ = "aagrawal"
9
+ __version__ = "0.9"
10
+
11
+ # Interface for accessing the VQA dataset.
12
+
13
+ # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
14
+ # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
15
+
16
+ # The following functions are defined:
17
+ # VQA - VQA class that loads VQA annotation file and prepares data structures.
18
+ # getQuesIds - Get question ids that satisfy given filter conditions.
19
+ # getImgIds - Get image ids that satisfy given filter conditions.
20
+ # loadQA - Load questions and answers with the specified question ids.
21
+ # showQA - Display the specified questions and answers.
22
+ # loadRes - Load result file and create result object.
23
+
24
+ # Help on each function can be accessed by: "help(COCO.function)"
25
+
26
+ import json
27
+ import datetime
28
+ import copy
29
+
30
+
31
+ class VQA:
32
+ def __init__(self, annotation_file=None, question_file=None):
33
+ """
34
+ Constructor of VQA helper class for reading and visualizing questions and answers.
35
+ :param annotation_file (str): location of VQA annotation file
36
+ :return:
37
+ """
38
+ # load dataset
39
+ self.dataset = {}
40
+ self.questions = {}
41
+ self.qa = {}
42
+ self.qqa = {}
43
+ self.imgToQA = {}
44
+ if not annotation_file == None and not question_file == None:
45
+ print("loading VQA annotations and questions into memory...")
46
+ time_t = datetime.datetime.utcnow()
47
+ dataset = json.load(open(annotation_file, "r"))
48
+ questions = json.load(open(question_file, "r"))
49
+ self.dataset = dataset
50
+ self.questions = questions
51
+ self.createIndex()
52
+
53
+ def createIndex(self):
54
+ # create index
55
+ print("creating index...")
56
+ imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]}
57
+ qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
58
+ qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
59
+ for ann in self.dataset["annotations"]:
60
+ imgToQA[ann["image_id"]] += [ann]
61
+ qa[ann["question_id"]] = ann
62
+ for ques in self.questions["questions"]:
63
+ qqa[ques["question_id"]] = ques
64
+ print("index created!")
65
+
66
+ # create class members
67
+ self.qa = qa
68
+ self.qqa = qqa
69
+ self.imgToQA = imgToQA
70
+
71
+ def info(self):
72
+ """
73
+ Print information about the VQA annotation file.
74
+ :return:
75
+ """
76
+ for key, value in self.datset["info"].items():
77
+ print("%s: %s" % (key, value))
78
+
79
+ def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
80
+ """
81
+ Get question ids that satisfy given filter conditions. default skips that filter
82
+ :param imgIds (int array) : get question ids for given imgs
83
+ quesTypes (str array) : get question ids for given question types
84
+ ansTypes (str array) : get question ids for given answer types
85
+ :return: ids (int array) : integer array of question ids
86
+ """
87
+ imgIds = imgIds if type(imgIds) == list else [imgIds]
88
+ quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
89
+ ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
90
+
91
+ if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
92
+ anns = self.dataset["annotations"]
93
+ else:
94
+ if not len(imgIds) == 0:
95
+ anns = sum(
96
+ [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],
97
+ [],
98
+ )
99
+ else:
100
+ anns = self.dataset["annotations"]
101
+ anns = (
102
+ anns
103
+ if len(quesTypes) == 0
104
+ else [ann for ann in anns if ann["question_type"] in quesTypes]
105
+ )
106
+ anns = (
107
+ anns
108
+ if len(ansTypes) == 0
109
+ else [ann for ann in anns if ann["answer_type"] in ansTypes]
110
+ )
111
+ ids = [ann["question_id"] for ann in anns]
112
+ return ids
113
+
114
+ def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
115
+ """
116
+ Get image ids that satisfy given filter conditions. default skips that filter
117
+ :param quesIds (int array) : get image ids for given question ids
118
+ quesTypes (str array) : get image ids for given question types
119
+ ansTypes (str array) : get image ids for given answer types
120
+ :return: ids (int array) : integer array of image ids
121
+ """
122
+ quesIds = quesIds if type(quesIds) == list else [quesIds]
123
+ quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
124
+ ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
125
+
126
+ if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
127
+ anns = self.dataset["annotations"]
128
+ else:
129
+ if not len(quesIds) == 0:
130
+ anns = sum(
131
+ [self.qa[quesId] for quesId in quesIds if quesId in self.qa], []
132
+ )
133
+ else:
134
+ anns = self.dataset["annotations"]
135
+ anns = (
136
+ anns
137
+ if len(quesTypes) == 0
138
+ else [ann for ann in anns if ann["question_type"] in quesTypes]
139
+ )
140
+ anns = (
141
+ anns
142
+ if len(ansTypes) == 0
143
+ else [ann for ann in anns if ann["answer_type"] in ansTypes]
144
+ )
145
+ ids = [ann["image_id"] for ann in anns]
146
+ return ids
147
+
148
+ def loadQA(self, ids=[]):
149
+ """
150
+ Load questions and answers with the specified question ids.
151
+ :param ids (int array) : integer ids specifying question ids
152
+ :return: qa (object array) : loaded qa objects
153
+ """
154
+ if type(ids) == list:
155
+ return [self.qa[id] for id in ids]
156
+ elif type(ids) == int:
157
+ return [self.qa[ids]]
158
+
159
+ def showQA(self, anns):
160
+ """
161
+ Display the specified annotations.
162
+ :param anns (array of object): annotations to display
163
+ :return: None
164
+ """
165
+ if len(anns) == 0:
166
+ return 0
167
+ for ann in anns:
168
+ quesId = ann["question_id"]
169
+ print("Question: %s" % (self.qqa[quesId]["question"]))
170
+ for ans in ann["answers"]:
171
+ print("Answer %d: %s" % (ans["answer_id"], ans["answer"]))
172
+
173
+ def loadRes(self, resFile, quesFile):
174
+ """
175
+ Load result file and return a result object.
176
+ :param resFile (str) : file name of result file
177
+ :return: res (obj) : result api object
178
+ """
179
+ res = VQA()
180
+ res.questions = json.load(open(quesFile))
181
+ res.dataset["info"] = copy.deepcopy(self.questions["info"])
182
+ res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"])
183
+ res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"])
184
+ res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"])
185
+ res.dataset["license"] = copy.deepcopy(self.questions["license"])
186
+
187
+ print("Loading and preparing results... ")
188
+ time_t = datetime.datetime.utcnow()
189
+ anns = json.load(open(resFile))
190
+ assert type(anns) == list, "results is not an array of objects"
191
+ annsQuesIds = [ann["question_id"] for ann in anns]
192
+ assert set(annsQuesIds) == set(
193
+ self.getQuesIds()
194
+ ), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file."
195
+ for ann in anns:
196
+ quesId = ann["question_id"]
197
+ if res.dataset["task_type"] == "Multiple Choice":
198
+ assert (
199
+ ann["answer"] in self.qqa[quesId]["multiple_choices"]
200
+ ), "predicted answer is not one of the multiple choices"
201
+ qaAnn = self.qa[quesId]
202
+ ann["image_id"] = qaAnn["image_id"]
203
+ ann["question_type"] = qaAnn["question_type"]
204
+ ann["answer_type"] = qaAnn["answer_type"]
205
+ print(
206
+ "DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds())
207
+ )
208
+
209
+ res.dataset["annotations"] = anns
210
+ res.createIndex()
211
+ return res
lavis/common/vqa_tools/vqa_eval.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ # coding=utf-8
9
+
10
+ __author__ = "aagrawal"
11
+
12
+ # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
13
+ # (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
14
+ import sys
15
+ import re
16
+
17
+
18
+ class VQAEval:
19
+ def __init__(self, vqa=None, vqaRes=None, n=2):
20
+ self.n = n
21
+ self.accuracy = {}
22
+ self.evalQA = {}
23
+ self.evalQuesType = {}
24
+ self.evalAnsType = {}
25
+ self.vqa = vqa
26
+ self.vqaRes = vqaRes
27
+ if vqa is not None:
28
+ self.params = {"question_id": vqa.getQuesIds()}
29
+ self.contractions = {
30
+ "aint": "ain't",
31
+ "arent": "aren't",
32
+ "cant": "can't",
33
+ "couldve": "could've",
34
+ "couldnt": "couldn't",
35
+ "couldn'tve": "couldn't've",
36
+ "couldnt've": "couldn't've",
37
+ "didnt": "didn't",
38
+ "doesnt": "doesn't",
39
+ "dont": "don't",
40
+ "hadnt": "hadn't",
41
+ "hadnt've": "hadn't've",
42
+ "hadn'tve": "hadn't've",
43
+ "hasnt": "hasn't",
44
+ "havent": "haven't",
45
+ "hed": "he'd",
46
+ "hed've": "he'd've",
47
+ "he'dve": "he'd've",
48
+ "hes": "he's",
49
+ "howd": "how'd",
50
+ "howll": "how'll",
51
+ "hows": "how's",
52
+ "Id've": "I'd've",
53
+ "I'dve": "I'd've",
54
+ "Im": "I'm",
55
+ "Ive": "I've",
56
+ "isnt": "isn't",
57
+ "itd": "it'd",
58
+ "itd've": "it'd've",
59
+ "it'dve": "it'd've",
60
+ "itll": "it'll",
61
+ "let's": "let's",
62
+ "maam": "ma'am",
63
+ "mightnt": "mightn't",
64
+ "mightnt've": "mightn't've",
65
+ "mightn'tve": "mightn't've",
66
+ "mightve": "might've",
67
+ "mustnt": "mustn't",
68
+ "mustve": "must've",
69
+ "neednt": "needn't",
70
+ "notve": "not've",
71
+ "oclock": "o'clock",
72
+ "oughtnt": "oughtn't",
73
+ "ow's'at": "'ow's'at",
74
+ "'ows'at": "'ow's'at",
75
+ "'ow'sat": "'ow's'at",
76
+ "shant": "shan't",
77
+ "shed've": "she'd've",
78
+ "she'dve": "she'd've",
79
+ "she's": "she's",
80
+ "shouldve": "should've",
81
+ "shouldnt": "shouldn't",
82
+ "shouldnt've": "shouldn't've",
83
+ "shouldn'tve": "shouldn't've",
84
+ "somebody'd": "somebodyd",
85
+ "somebodyd've": "somebody'd've",
86
+ "somebody'dve": "somebody'd've",
87
+ "somebodyll": "somebody'll",
88
+ "somebodys": "somebody's",
89
+ "someoned": "someone'd",
90
+ "someoned've": "someone'd've",
91
+ "someone'dve": "someone'd've",
92
+ "someonell": "someone'll",
93
+ "someones": "someone's",
94
+ "somethingd": "something'd",
95
+ "somethingd've": "something'd've",
96
+ "something'dve": "something'd've",
97
+ "somethingll": "something'll",
98
+ "thats": "that's",
99
+ "thered": "there'd",
100
+ "thered've": "there'd've",
101
+ "there'dve": "there'd've",
102
+ "therere": "there're",
103
+ "theres": "there's",
104
+ "theyd": "they'd",
105
+ "theyd've": "they'd've",
106
+ "they'dve": "they'd've",
107
+ "theyll": "they'll",
108
+ "theyre": "they're",
109
+ "theyve": "they've",
110
+ "twas": "'twas",
111
+ "wasnt": "wasn't",
112
+ "wed've": "we'd've",
113
+ "we'dve": "we'd've",
114
+ "weve": "we've",
115
+ "werent": "weren't",
116
+ "whatll": "what'll",
117
+ "whatre": "what're",
118
+ "whats": "what's",
119
+ "whatve": "what've",
120
+ "whens": "when's",
121
+ "whered": "where'd",
122
+ "wheres": "where's",
123
+ "whereve": "where've",
124
+ "whod": "who'd",
125
+ "whod've": "who'd've",
126
+ "who'dve": "who'd've",
127
+ "wholl": "who'll",
128
+ "whos": "who's",
129
+ "whove": "who've",
130
+ "whyll": "why'll",
131
+ "whyre": "why're",
132
+ "whys": "why's",
133
+ "wont": "won't",
134
+ "wouldve": "would've",
135
+ "wouldnt": "wouldn't",
136
+ "wouldnt've": "wouldn't've",
137
+ "wouldn'tve": "wouldn't've",
138
+ "yall": "y'all",
139
+ "yall'll": "y'all'll",
140
+ "y'allll": "y'all'll",
141
+ "yall'd've": "y'all'd've",
142
+ "y'alld've": "y'all'd've",
143
+ "y'all'dve": "y'all'd've",
144
+ "youd": "you'd",
145
+ "youd've": "you'd've",
146
+ "you'dve": "you'd've",
147
+ "youll": "you'll",
148
+ "youre": "you're",
149
+ "youve": "you've",
150
+ }
151
+ self.manualMap = {
152
+ "none": "0",
153
+ "zero": "0",
154
+ "one": "1",
155
+ "two": "2",
156
+ "three": "3",
157
+ "four": "4",
158
+ "five": "5",
159
+ "six": "6",
160
+ "seven": "7",
161
+ "eight": "8",
162
+ "nine": "9",
163
+ "ten": "10",
164
+ }
165
+ self.articles = ["a", "an", "the"]
166
+
167
+ self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
168
+ self.commaStrip = re.compile("(\d)(,)(\d)")
169
+ self.punct = [
170
+ ";",
171
+ r"/",
172
+ "[",
173
+ "]",
174
+ '"',
175
+ "{",
176
+ "}",
177
+ "(",
178
+ ")",
179
+ "=",
180
+ "+",
181
+ "\\",
182
+ "_",
183
+ "-",
184
+ ">",
185
+ "<",
186
+ "@",
187
+ "`",
188
+ ",",
189
+ "?",
190
+ "!",
191
+ ]
192
+
193
+ def evaluate(self, quesIds=None):
194
+ if quesIds == None:
195
+ quesIds = [quesId for quesId in self.params["question_id"]]
196
+ gts = {}
197
+ res = {}
198
+ for quesId in quesIds:
199
+ gts[quesId] = self.vqa.qa[quesId]
200
+ res[quesId] = self.vqaRes.qa[quesId]
201
+
202
+ # =================================================
203
+ # Compute accuracy
204
+ # =================================================
205
+ accQA = []
206
+ accQuesType = {}
207
+ accAnsType = {}
208
+ print("computing accuracy")
209
+ step = 0
210
+ for quesId in quesIds:
211
+ resAns = res[quesId]["answer"]
212
+ resAns = resAns.replace("\n", " ")
213
+ resAns = resAns.replace("\t", " ")
214
+ resAns = resAns.strip()
215
+ resAns = self.processPunctuation(resAns)
216
+ resAns = self.processDigitArticle(resAns)
217
+ gtAcc = []
218
+ gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
219
+ if len(set(gtAnswers)) > 1:
220
+ for ansDic in gts[quesId]["answers"]:
221
+ ansDic["answer"] = self.processPunctuation(ansDic["answer"])
222
+ for gtAnsDatum in gts[quesId]["answers"]:
223
+ otherGTAns = [
224
+ item for item in gts[quesId]["answers"] if item != gtAnsDatum
225
+ ]
226
+ matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
227
+ acc = min(1, float(len(matchingAns)) / 3)
228
+ gtAcc.append(acc)
229
+ quesType = gts[quesId]["question_type"]
230
+ ansType = gts[quesId]["answer_type"]
231
+ avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
232
+ accQA.append(avgGTAcc)
233
+ if quesType not in accQuesType:
234
+ accQuesType[quesType] = []
235
+ accQuesType[quesType].append(avgGTAcc)
236
+ if ansType not in accAnsType:
237
+ accAnsType[ansType] = []
238
+ accAnsType[ansType].append(avgGTAcc)
239
+ self.setEvalQA(quesId, avgGTAcc)
240
+ self.setEvalQuesType(quesId, quesType, avgGTAcc)
241
+ self.setEvalAnsType(quesId, ansType, avgGTAcc)
242
+ if step % 100 == 0:
243
+ self.updateProgress(step / float(len(quesIds)))
244
+ step = step + 1
245
+
246
+ self.setAccuracy(accQA, accQuesType, accAnsType)
247
+ print("Done computing accuracy")
248
+
249
+ def processPunctuation(self, inText):
250
+ outText = inText
251
+ for p in self.punct:
252
+ if (p + " " in inText or " " + p in inText) or (
253
+ re.search(self.commaStrip, inText) != None
254
+ ):
255
+ outText = outText.replace(p, "")
256
+ else:
257
+ outText = outText.replace(p, " ")
258
+ outText = self.periodStrip.sub("", outText, re.UNICODE)
259
+ return outText
260
+
261
+ def processDigitArticle(self, inText):
262
+ outText = []
263
+ tempText = inText.lower().split()
264
+ for word in tempText:
265
+ word = self.manualMap.setdefault(word, word)
266
+ if word not in self.articles:
267
+ outText.append(word)
268
+ else:
269
+ pass
270
+ for wordId, word in enumerate(outText):
271
+ if word in self.contractions:
272
+ outText[wordId] = self.contractions[word]
273
+ outText = " ".join(outText)
274
+ return outText
275
+
276
+ def setAccuracy(self, accQA, accQuesType, accAnsType):
277
+ self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
278
+ self.accuracy["perQuestionType"] = {
279
+ quesType: round(
280
+ 100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
281
+ self.n,
282
+ )
283
+ for quesType in accQuesType
284
+ }
285
+ self.accuracy["perAnswerType"] = {
286
+ ansType: round(
287
+ 100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
288
+ )
289
+ for ansType in accAnsType
290
+ }
291
+
292
+ def setEvalQA(self, quesId, acc):
293
+ self.evalQA[quesId] = round(100 * acc, self.n)
294
+
295
+ def setEvalQuesType(self, quesId, quesType, acc):
296
+ if quesType not in self.evalQuesType:
297
+ self.evalQuesType[quesType] = {}
298
+ self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
299
+
300
+ def setEvalAnsType(self, quesId, ansType, acc):
301
+ if ansType not in self.evalAnsType:
302
+ self.evalAnsType[ansType] = {}
303
+ self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
304
+
305
+ def updateProgress(self, progress):
306
+ barLength = 20
307
+ status = ""
308
+ if isinstance(progress, int):
309
+ progress = float(progress)
310
+ if not isinstance(progress, float):
311
+ progress = 0
312
+ status = "error: progress var must be float\r\n"
313
+ if progress < 0:
314
+ progress = 0
315
+ status = "Halt...\r\n"
316
+ if progress >= 1:
317
+ progress = 1
318
+ status = "Done...\r\n"
319
+ block = int(round(barLength * progress))
320
+ text = "\rFinshed Percent: [{0}] {1}% {2}".format(
321
+ "#" * block + "-" * (barLength - block), int(progress * 100), status
322
+ )
323
+ sys.stdout.write(text)
324
+ sys.stdout.flush()
lavis/configs/datasets/aokvqa/defaults.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ aok_vqa:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url:
16
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_train.json
17
+ storage:
18
+ - aokvqa/annotations/aokvqa_v1p0_train.json
19
+ val:
20
+ url:
21
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_val.json
22
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/specialized_vocab_train.json
23
+ storage:
24
+ - aokvqa/annotations/aokvqa_v1p0_val.json
25
+ - aokvqa/annotations/specialized_vocab_train_lavis.json
26
+ # - aokvqa/annotations/large_vocab_train_lavis.json
27
+ test:
28
+ url:
29
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_test.json
30
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/specialized_vocab_train.json
31
+ storage:
32
+ - aokvqa/annotations/aokvqa_v1p0_test.json
33
+ - aokvqa/annotations/specialized_vocab_train_lavis.json
34
+ images:
35
+ storage: coco/images/
lavis/configs/datasets/avsd/defaults_dial.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ avsd_dialogue: # name of the dataset builder
8
+ dataset_card: dataset_card/avsd_dialogue.md
9
+ data_type: features #extracted features of videos (I3D, VGGish) # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url: https://storage.googleapis.com/sfr-vision-language-research/datasets/avsd_dstc7_train.json
16
+ storage: avsd/annotations/train.json
17
+ val:
18
+ url: https://storage.googleapis.com/sfr-vision-language-research/datasets/avsd_dstc7_val.json
19
+ storage: avsd/annotations/val.json
20
+ test:
21
+ url: https://storage.googleapis.com/sfr-vision-language-research/datasets/avsd_dstc7_test.json
22
+ storage: avsd/annotations/test.json
23
+ features:
24
+ storage: avsd/features/
lavis/configs/datasets/coco/defaults_cap.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ coco_caption: # name of the dataset builder
8
+ dataset_card: dataset_card/coco_caption.md
9
+ # data_dir: ${env.data_dir}/datasets
10
+ data_type: images # [images|videos|features]
11
+
12
+ build_info:
13
+ # Be careful not to append minus sign (-) before split to avoid itemizing
14
+ annotations:
15
+ train:
16
+ url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json
17
+ md5: aa31ac474cf6250ebb81d18348a07ed8
18
+ storage: coco/annotations/coco_karpathy_train.json
19
+ val:
20
+ url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json
21
+ md5: b273847456ef5580e33713b1f7de52a0
22
+ storage: coco/annotations/coco_karpathy_val.json
23
+ test:
24
+ url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json
25
+ md5: 3ff34b0ef2db02d01c37399f6a2a6cd1
26
+ storage: coco/annotations/coco_karpathy_test.json
27
+ images:
28
+ storage: coco/images/
lavis/configs/datasets/coco/defaults_ret.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ coco_retrieval:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json
16
+ md5: aa31ac474cf6250ebb81d18348a07ed8
17
+ storage: coco/annotations/coco_karpathy_train.json
18
+ val:
19
+ url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json
20
+ md5: b273847456ef5580e33713b1f7de52a0
21
+ storage: coco/annotations/coco_karpathy_val.json
22
+ test:
23
+ url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json
24
+ md5: 3ff34b0ef2db02d01c37399f6a2a6cd1
25
+ storage: coco/annotations/coco_karpathy_test.json
26
+ images:
27
+ storage: coco/images/
lavis/configs/datasets/coco/defaults_vqa.yaml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ coco_vqa:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url:
16
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_train.json
17
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val.json
18
+ storage:
19
+ - coco/annotations/vqa_train.json
20
+ - coco/annotations/vqa_val.json
21
+ val:
22
+ url:
23
+ # TODO make this order insensitive
24
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val_eval.json
25
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json
26
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_OpenEnded_mscoco_val2014_questions.json
27
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_mscoco_val2014_annotations.json
28
+ storage:
29
+ - coco/annotations/vqa_val_eval.json
30
+ - coco/annotations/answer_list.json
31
+ - coco/annotations/v2_OpenEnded_mscoco_val2014_questions.json
32
+ - coco/annotations/v2_mscoco_val2014_annotations.json
33
+ test:
34
+ url:
35
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_test.json
36
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json
37
+ storage:
38
+ - coco/annotations/vqa_test.json
39
+ - coco/annotations/answer_list.json
40
+ images:
41
+ storage: coco/images/
lavis/configs/datasets/coco/eval_vqa.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ coco_vqa:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ val:
15
+ url:
16
+ # TODO make this order insensitive
17
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val_eval.json
18
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json
19
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_OpenEnded_mscoco_val2014_questions.json
20
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_mscoco_val2014_annotations.json
21
+ storage:
22
+ - coco/annotations/vqa_val_eval.json
23
+ - coco/annotations/answer_list.json
24
+ - coco/annotations/v2_OpenEnded_mscoco_val2014_questions.json
25
+ - coco/annotations/v2_mscoco_val2014_annotations.json
26
+ images:
27
+ storage: coco/images/
lavis/configs/datasets/conceptual_caption/defaults_12m.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ conceptual_caption_12m:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url:
16
+ - /export/home/workspace/datasets/cc12m.json
17
+ storage:
18
+ - conceptual_caption/annotations/cc12m.json
19
+ images:
20
+ storage: conceptual_caption/images_12m
lavis/configs/datasets/conceptual_caption/defaults_3m.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ conceptual_caption_3m:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url:
16
+ - /export/home/workspace/datasets/cc3m.json
17
+ storage:
18
+ - conceptual_caption/annotations/cc3m.json
19
+ images:
20
+ storage: conceptual_caption/images
lavis/configs/datasets/didemo/defaults_ret.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ didemo_retrieval: # name of the dataset builder
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: videos # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/didemo/retrieval_train.json
16
+ storage: didemo/annotations/retrieval_train.json
17
+ val:
18
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/didemo/retrieval_val.json
19
+ storage: didemo/annotations/retrieval_val.json
20
+ test:
21
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/didemo/retrieval_test.json
22
+ storage: didemo/annotations/retrieval_test.json
23
+ videos:
24
+ storage: didemo/videos
25
+ # storage: /export/share/dongxuli/data/didemo_retrieval/videos
lavis/configs/datasets/flickr30k/defaults.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ flickr30k:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images
10
+
11
+ build_info:
12
+ annotations:
13
+ train:
14
+ url: https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json
15
+ storage: flickr30k/annotations/train.json
16
+ val:
17
+ url: https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json
18
+ storage: flickr30k/annotations/val.json
19
+ test:
20
+ url: https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json
21
+ storage: flickr30k/annotations/test.json
22
+ images:
23
+ storage: flickr30k/images
24
+ # storage: /export/share/datasets/vision/flickr30k
lavis/configs/datasets/gqa/balanced_testdev.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ gqa:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url:
16
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json
17
+ storage:
18
+ - gqa/annotations/train_balanced_questions.json
19
+ val:
20
+ url:
21
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/testdev_balanced_questions.json
22
+ storage:
23
+ - gqa/annotations/testdev_balanced_questions.json
24
+ test:
25
+ url:
26
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/test_balanced_questions.json
27
+ storage:
28
+ - gqa/annotations/test_balanced_questions.json
29
+ images:
30
+ storage: gqa/images/
lavis/configs/datasets/gqa/balanced_val.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ gqa:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url:
16
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json
17
+ storage:
18
+ - gqa/annotations/train_balanced_questions.json
19
+ val:
20
+ url:
21
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/val_balanced_questions.json
22
+ storage:
23
+ - gqa/annotations/val_balanced_questions.json
24
+ test:
25
+ url:
26
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/test_balanced_questions.json
27
+ storage:
28
+ - gqa/annotations/test_balanced_questions.json
29
+ images:
30
+ storage: gqa/images/
lavis/configs/datasets/gqa/defaults.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ gqa:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url:
16
+ - /export/share/datasets/vision/GQA/questions1.2/train_all_questions/train_all_questions_0.json
17
+ - /export/share/datasets/vision/GQA/questions1.2/val_all_questions.json
18
+ storage:
19
+ - gqa/annotations/train_all_questions_0.json
20
+ - gqa/annotations/val_all_questions.json
21
+ val:
22
+ url:
23
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_val.json
24
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/large_vocab_train_lavis.json
25
+ storage:
26
+ - aokvqa/annotations/aokvqa_v1p0_val.json
27
+ - aokvqa/annotations/large_vocab_train_lavis.json
28
+ test:
29
+ url:
30
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_test.json
31
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/large_vocab_train_lavis.json
32
+ storage:
33
+ - aokvqa/annotations/aokvqa_v1p0_test.json
34
+ - aokvqa/annotations/large_vocab_train_lavis.json
35
+ images:
36
+ storage: gqa/images/
lavis/configs/datasets/imagenet/defaults.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ imagenet:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ splits: ["val"]
14
+ images:
15
+ storage: /export/share/datasets/vision/imagenet
lavis/configs/datasets/laion/defaults_2B_multi.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ laion2B_multi:
8
+
9
+ data_type: images
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ storage: /export/laion/laion2B-multi/part-00000/{00000..01743}.tar
lavis/configs/datasets/msrvtt/defaults_cap.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ msrvtt_cap: # name of the dataset builder
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: videos # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/cap_train.json
16
+ storage: msrvtt/annotations/cap_train.json
17
+ val:
18
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/cap_val.json
19
+ storage: msrvtt/annotations/cap_val.json
20
+ test:
21
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/cap_test.json
22
+ storage: msrvtt/annotations/cap_test.json
23
+ videos:
24
+ storage: msrvtt/videos
lavis/configs/datasets/msrvtt/defaults_qa.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ msrvtt_qa: # name of the dataset builder
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: videos # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/qa_train.json
16
+ storage: msrvtt/annotations/qa_train.json
17
+ val:
18
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/qa_val.json
19
+ storage: msrvtt/annotations/qa_val.json
20
+ test:
21
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/qa_test.json
22
+ storage: msrvtt/annotations/qa_test.json
23
+ ans2label:
24
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/train_ans2label.json
25
+ storage: msrvtt/annotations/qa_ans2label.json
26
+ videos:
27
+ storage: msrvtt/videos
lavis/configs/datasets/msrvtt/defaults_ret.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ msrvtt_retrieval: # name of the dataset builder
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: videos # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/retrieval_train.json
16
+ storage: msrvtt/annotations/retrieval_train.json
17
+ val:
18
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/retrieval_val.json
19
+ storage: msrvtt/annotations/retrieval_val.json
20
+ test:
21
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/retrieval_test.json
22
+ storage: msrvtt/annotations/retrieval_test.json
23
+ videos:
24
+ storage: msrvtt/videos
lavis/configs/datasets/msvd/defaults_cap.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ msvd_cap: # name of the dataset builder
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: videos # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/cap_train.json
16
+ storage: msvd/annotations/cap_train.json
17
+ val:
18
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/cap_val.json
19
+ storage: msvd/annotations/cap_val.json
20
+ test:
21
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/cap_test.json
22
+ storage: msvd/annotations/cap_test.json
23
+ videos:
24
+ storage: msvd/videos
lavis/configs/datasets/msvd/defaults_qa.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ msvd_qa: # name of the dataset builder
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: videos # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/qa_train.json
16
+ storage: msvd/annotations/qa_train.json
17
+ val:
18
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/qa_val.json
19
+ storage: msvd/annotations/qa_val.json
20
+ test:
21
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/qa_test.json
22
+ storage: msvd/annotations/qa_test.json
23
+ ans2label:
24
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/train_ans2label.json
25
+ storage: msvd/annotations/qa_ans2label.json
26
+ videos:
27
+ storage: msvd/videos
28
+
29
+ instance_id_key: question_id
lavis/configs/datasets/nlvr/defaults.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ nlvr:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/nlvr/nlvr_train.json
16
+ storage: nlvr/annotations/train.json
17
+ val:
18
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/nlvr/nlvr_dev.json
19
+ storage: nlvr/annotations/dev.json
20
+ test:
21
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/nlvr/nlvr_dev.json
22
+ storage: nlvr/annotations/test.json
23
+ images:
24
+ storage: /export/share/datasets/vision/NLVR2/
lavis/configs/datasets/nocaps/defaults.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ nocaps: # name of the dataset builder
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ val:
15
+ url: https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json
16
+ storage: nocaps/annotations/nocaps_val.json
17
+ test:
18
+ url: https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json
19
+ storage: nocaps/annotations/nocaps_test.json
20
+ images:
21
+ storage: nocaps/images
22
+ # storage: /export/share/datasets/vision/nocaps/
lavis/configs/datasets/okvqa/defaults.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ ok_vqa:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url:
16
+ # TODO make this order insensitive
17
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_train.json
18
+ # - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/OpenEnded_mscoco_train2014_questions.json
19
+ # - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_train2014_annotations.json
20
+ storage:
21
+ - okvqa/annotations/okvqa_train.json
22
+ # - okvqa/annotations/OpenEnded_mscoco_train2014_questions.json
23
+ # - okvqa/annotations/mscoco_train2014_annotations.json
24
+ test:
25
+ url:
26
+ # TODO make this order insensitive
27
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_val_eval.json
28
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_answer_list_train.json
29
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/OpenEnded_mscoco_val2014_questions.json
30
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json
31
+ storage:
32
+ - okvqa/annotations/vqa_val_eval.json
33
+ - okvqa/annotations/answer_list.json
34
+ - okvqa/annotations/OpenEnded_mscoco_val2014_questions.json
35
+ - okvqa/annotations/mscoco_val2014_annotations.json
36
+ images:
37
+ storage: coco/images/
lavis/configs/datasets/sbu_caption/defaults.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ sbu_caption:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url:
16
+ - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/sbu/sbu.json
17
+ # - /export/share/dongxuli/data/lavis/sbu/annotation/sbu.json
18
+ storage:
19
+ - sbu_captions/annotations/sbu.json
20
+ images:
21
+ storage: sbu_captions/images
22
+ # storage: /export/share/datasets/vision_language/sbu_resize
lavis/configs/datasets/snli_ve/defaults.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ snli_ve:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url: /export/share/dongxuli/data/lavis/snli/annotation/ve_train.json
16
+ storage: snli/annotations/ve_train.json
17
+ val:
18
+ url: /export/share/dongxuli/data/lavis/snli/annotation/ve_dev.json
19
+ storage: snli/annotations/ve_dev.json
20
+ test:
21
+ url: /export/share/dongxuli/data/lavis/snli/annotation/ve_test.json
22
+ storage: snli/annotations/ve_test.json
23
+ images:
24
+ storage: flickr30k/images/flickr30k-images
25
+ # storage: /export/share/datasets/vision/flickr30k/flickr30k-images
lavis/configs/datasets/vatex/defaults_cap.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ msvd_cap: # name of the dataset builder
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: videos # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vatex/cap_train.json
16
+ storage: vatex/annotations/cap_train.json
17
+ val:
18
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vatex/cap_val.json
19
+ storage: vatex/annotations/cap_val.json
20
+ test:
21
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vatex/cap_private_test.json
22
+ storage: vatex/annotations/cap_test.json
23
+ videos:
24
+ storage: /export/share/dongxuli/data/vatex
lavis/configs/datasets/vg/defaults_caption.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ vg_caption:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/visual_genome/vg_caption.json
16
+ storage: vg/annotations/vg_caption.json
17
+ images:
18
+ storage: vg/images/
lavis/configs/datasets/vg/defaults_vqa.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ datasets:
7
+ vg_vqa:
8
+ # data_dir: ${env.data_dir}/datasets
9
+ data_type: images # [images|videos|features]
10
+
11
+ build_info:
12
+ # Be careful not to append minus sign (-) before split to avoid itemizing
13
+ annotations:
14
+ train:
15
+ url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/visual_genome/vg_qa.json
16
+ storage: vg/annotations/vg_qa.json
17
+ images:
18
+ storage: vg/images/
lavis/configs/default.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ env:
7
+ # For default users
8
+ # cache_root: "cache"
9
+ # For internal use with persistent storage
10
+ cache_root: "/export/home/.cache/lavis"
lavis/configs/models/albef_classification_ve.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ model:
7
+ arch: albef_classification
8
+ load_finetuned: True
9
+
10
+ finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALBEF/albef_snli_ve_lavis.pt"
11
+ pretrained: "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth"
12
+
13
+ num_classes: 3
14
+
15
+ use_distill: True
16
+ momentum: 0.995
17
+ alpha: 0.4
18
+
19
+ # vit encoder
20
+ vit_type: "base"
21
+ vit_grad_ckpt: False
22
+ vit_ckpt_layer: 0
23
+ vit_layer_norm_epsilon: 1e-6
24
+
25
+ image_size: 384
26
+
27
+ # bert config
28
+ med_config_path: "configs/models/med_config_albef.json"
29
+
30
+ preprocess:
31
+ vis_processor:
32
+ train:
33
+ name: "blip_image_train"
34
+ eval:
35
+ name: "blip_image_eval"
36
+ text_processor:
37
+ train:
38
+ name: "blip_caption"
39
+ eval:
40
+ name: "blip_caption"
lavis/configs/models/albef_feature_extractor.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ model:
7
+ arch: albef_pretrain
8
+ pretrained: "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth"
9
+
10
+ # vit encoder
11
+ vit_type: "base"
12
+ image_size: 224
13
+ vit_ckpt_layer: 0
14
+ vit_drop_path_rate: 0
15
+ vit_layer_norm_epsilon: 1e-6
16
+ vit_grad_ckpt: False
17
+
18
+ # bert config
19
+ med_config_path: "configs/models/med_config_albef.json"
20
+
21
+ embed_dim: 256
22
+
23
+ preprocess:
24
+ vis_processor:
25
+ eval:
26
+ name: "blip_image_eval"
27
+ image_size: 224
28
+ text_processor:
29
+ eval:
30
+ name: "blip_caption"
lavis/configs/models/albef_nlvr.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ model:
7
+ arch: albef_nlvr
8
+ load_finetuned: True
9
+
10
+ pretrained: "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/pretrain_model_nlvr.pth"
11
+ finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALBEF/albef_nlvr_lavis.pt"
12
+
13
+ num_classes: 2
14
+
15
+ use_distill: True
16
+ momentum: 0.995
17
+ alpha: 0.4
18
+
19
+ # vit encoder
20
+ vit_type: "base"
21
+ vit_grad_ckpt: False
22
+ vit_ckpt_layer: 0
23
+ vit_layer_norm_epsilon: 1e-6
24
+
25
+ image_size: 384
26
+
27
+ # bert config
28
+ med_config_path: "configs/models/med_config_albef.json"
29
+
30
+ preprocess:
31
+ vis_processor:
32
+ train:
33
+ name: "blip_image_train"
34
+ image_size: 384
35
+ eval:
36
+ name: "blip_image_eval"
37
+ image_size: 384
38
+ text_processor:
39
+ train:
40
+ name: "blip_caption"
41
+ eval:
42
+ name: "blip_caption"
lavis/configs/models/albef_pretrain_base.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ model:
7
+ arch: albef_pretrain
8
+
9
+ load_pretrained: True
10
+ pretrained: "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth"
11
+
12
+ # vit encoder
13
+ vit_type: "base"
14
+ image_size: 224
15
+ vit_ckpt_layer: 0
16
+ vit_drop_path_rate: 0
17
+ vit_layer_norm_epsilon: 1e-6
18
+ vit_grad_ckpt: False
19
+
20
+ # bert config
21
+ med_config_path: "configs/models/med_config_albef.json"
22
+ mlm_mask_prob: 0.15
23
+
24
+ embed_dim: 256
25
+ momentum: 0.995
26
+ alpha: 0.4
27
+ temp: 0.07
28
+
29
+ max_txt_len: 30
30
+
31
+ preprocess:
32
+ vis_processor:
33
+ train:
34
+ name: "blip_image_train"
35
+ image_size: 256
36
+ text_processor:
37
+ train:
38
+ name: "blip_caption"
lavis/configs/models/albef_retrieval_coco.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ model:
7
+ arch: albef_retrieval
8
+ load_finetuned: True
9
+
10
+ pretrained: "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth"
11
+ finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALBEF/albef_coco_retrieval_lavis.pt"
12
+
13
+ queue_size: 65536
14
+
15
+ # vit encoder
16
+ vit_type: "base"
17
+ image_size: 384
18
+ vit_ckpt_layer: 0
19
+ vit_drop_path_rate: 0
20
+ vit_layer_norm_epsilon: 1e-6
21
+ vit_grad_ckpt: False
22
+
23
+ # bert config
24
+ med_config_path: "configs/models/med_config_albef.json"
25
+
26
+ embed_dim: 256
27
+ momentum: 0.995
28
+ alpha: 0.4
29
+ temp: 0.07
30
+ use_distill: True
31
+
32
+ max_txt_len: 30
33
+
34
+ preprocess:
35
+ vis_processor:
36
+ train:
37
+ name: "blip_image_train"
38
+ image_size: 384
39
+ eval:
40
+ name: "blip_image_eval"
41
+ image_size: 384
42
+ text_processor:
43
+ train:
44
+ name: "blip_caption"
45
+ eval:
46
+ name: "blip_caption"
lavis/configs/models/albef_retrieval_flickr.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ model:
7
+ arch: albef_retrieval
8
+ load_finetuned: True
9
+
10
+ pretrained: "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth"
11
+ finetuned: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALBEF/albef_flickr_retrieval_lavis.pt
12
+
13
+ queue_size: 65536
14
+
15
+ # vit encoder
16
+ vit_type: "base"
17
+ image_size: 384
18
+ vit_ckpt_layer: 0
19
+ vit_drop_path_rate: 0
20
+ vit_layer_norm_epsilon: 1e-6
21
+ vit_grad_ckpt: False
22
+
23
+ # bert config
24
+ med_config_path: "configs/models/med_config_albef.json"
25
+
26
+ embed_dim: 256
27
+ momentum: 0.995
28
+ alpha: 0.4
29
+ temp: 0.07
30
+ use_distill: True
31
+
32
+ max_txt_len: 30
33
+
34
+ preprocess:
35
+ vis_processor:
36
+ train:
37
+ name: "blip_image_train"
38
+ image_size: 384
39
+ eval:
40
+ name: "blip_image_eval"
41
+ image_size: 384
42
+ text_processor:
43
+ train:
44
+ name: "blip_caption"
45
+ eval:
46
+ name: "blip_caption"
lavis/configs/models/albef_vqav2.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ model:
7
+ arch: albef_vqa
8
+ load_finetuned: True
9
+
10
+ pretrained: "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth"
11
+ finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALBEF/albef_vqav2_lavis.pt"
12
+
13
+ use_distill: True
14
+ momentum: 0.995
15
+ alpha: 0.4
16
+
17
+ # vit encoder
18
+ vit_type: "base"
19
+ vit_grad_ckpt: False
20
+ vit_ckpt_layer: 0
21
+ vit_layer_norm_epsilon: 1e-6
22
+
23
+ image_size: 384
24
+
25
+ # bert config
26
+ med_config_path: "configs/models/med_config_albef.json"
27
+
28
+ preprocess:
29
+ vis_processor:
30
+ train:
31
+ name: "blip_image_train"
32
+ image_size: 384
33
+ eval:
34
+ name: "blip_image_eval"
35
+ image_size: 384
36
+ text_processor:
37
+ train:
38
+ name: "blip_question"
39
+ eval:
40
+ name: "blip_question"
lavis/configs/models/alpro_qa_msrvtt.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ model:
7
+ arch: alpro_qa
8
+ num_classes: 1500
9
+
10
+ load_finetuned: True
11
+
12
+ finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_msrvtt_qa.pth"
13
+ pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_pretrain.pt"
14
+
15
+ timesformer:
16
+ n_frms: 16
17
+ image_size: 224
18
+
19
+ patch_size: 16
20
+ attn_drop_rate: 0.
21
+ drop_rate: 0.
22
+ drop_path_rate: 0.1
23
+
24
+ use_grad_ckpt: True
25
+ ckpt_layer: 12
26
+
27
+ # bert config
28
+ med_config_path: "configs/models/bert_config_alpro.json"
29
+
30
+ preprocess:
31
+ vis_processor:
32
+ train:
33
+ name: "alpro_video_train"
34
+ n_frms: 16
35
+ image_size: 224
36
+ eval:
37
+ name: "alpro_video_eval"
38
+ n_frms: 16
39
+ image_size: 224
40
+ text_processor:
41
+ train:
42
+ name: "blip_caption"
43
+ eval:
44
+ name: "blip_caption"
lavis/configs/models/alpro_qa_msvd.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, salesforce.com, inc.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
5
+
6
+ model:
7
+ arch: alpro_qa
8
+ num_classes: 2423
9
+
10
+ load_finetuned: True
11
+
12
+ finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_msvd_qa.pth"
13
+ pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_pretrain.pt"
14
+
15
+ timesformer:
16
+ n_frms: 16
17
+ image_size: 224
18
+
19
+ patch_size: 16
20
+ attn_drop_rate: 0.
21
+ drop_rate: 0.
22
+ drop_path_rate: 0.1
23
+ use_grad_ckpt: True
24
+ ckpt_layer: 12
25
+
26
+ # bert config
27
+ med_config_path: "configs/models/bert_config_alpro.json"
28
+
29
+ preprocess:
30
+ vis_processor:
31
+ train:
32
+ name: "alpro_video_train"
33
+ n_frms: 16
34
+ image_size: 224
35
+ eval:
36
+ name: "alpro_video_eval"
37
+ n_frms: 16
38
+ image_size: 224
39
+ text_processor:
40
+ train:
41
+ name: "blip_caption"
42
+ eval:
43
+ name: "blip_caption"