stellaathena commited on
Commit
bb5cd12
·
1 Parent(s): 23ee17f

This should work

Browse files
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Aleph Alpha GmbH
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import re
3
+ from magma import Magma
4
+ from magma.image_input import ImageInput
5
+
6
+ model = Magma.from_checkpoint(
7
+ config_path = "configs/MAGMA_v1.yml",
8
+ checkpoint_path = "./mp_rank_00_model_states.pt",
9
+ device = 'cuda:0'
10
+ )
11
+
12
+ def generate(context, length, temperature, top_k):
13
+ context = context.strip()
14
+
15
+ url_regex = r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)'
16
+ lines = context.split('\n')
17
+ inputs = []
18
+ for line in lines:
19
+ if re.match(url_regex, line):
20
+ try:
21
+ inputs.append(ImageInput(line))
22
+ except Exception as e:
23
+ return str(e)
24
+ else:
25
+ inputs.append(line)
26
+
27
+ ## returns a tensor of shape: (1, 149, 4096)
28
+ embeddings = model.preprocess_inputs(inputs)
29
+
30
+ ## returns a list of length embeddings.shape[0] (batch size)
31
+ output = model.generate(
32
+ embeddings = embeddings,
33
+ max_steps = length,
34
+ temperature = (0.01 if temperature == 0 else temperature),
35
+ top_k = top_k
36
+ )
37
+
38
+ return context + output[0]
39
+
40
+ iface = gr.Interface(
41
+ fn=generate,
42
+ inputs=[
43
+ gr.inputs.Textbox(
44
+ label="Prompt (image URLs need to be on their own lines):",
45
+ default="https://www.art-prints-on-demand.com/kunst/thomas_cole/woods_hi.jpg\nDescribe the painting:",
46
+ lines=7),
47
+ gr.inputs.Slider(minimum=1, maximum=100, default=15, step=1, label="Output tokens:"),
48
+ gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.7, label='Temperature'),
49
+ gr.inputs.Slider(minimum=0, maximum=100, default=0, step=1, label='Top K')
50
+ ],
51
+ outputs=["textbox"]
52
+ ).launch(share=True)
53
+
54
+
configs/MAGMA_v1.yml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ # image encoder settings
3
+ encoder_name: 'clip_resnet_large',
4
+ adapter_config: {"mlp": {"adapter_type": "normal", "downsample_factor": 4}},
5
+ freeze_img_encoder: false,
6
+
7
+ # train settings
8
+ batch_size: 256,
9
+ train_steps: 150000,
10
+ lr: 8.0e-4,
11
+ min_lr: 0.0,
12
+ lr_decay_iters: 300000,
13
+ image_enc_lr: 2.0e-6,
14
+ use_image_embed_layernorm: true,
15
+ image_embed_dropout_prob: 0.1,
16
+ image_size: 384,
17
+
18
+ gradient_accumulation_steps: 8,
19
+ zero_stage: 2,
20
+ gradient_clipping: 1.0,
21
+
22
+ # dataset / save / load settings
23
+ train_dataset_name: 'conceptual_captions',
24
+ train_dataset_dir: '/mnt/localdisk/conceptual_captions',
25
+ eval_dataset_name: 'coco',
26
+ eval_dataset_dir: '/mnt/localdisk/coco_data',
27
+
28
+ save: "/mnt/shared_vol/checkpoints/multimodal_transformer_rn50x16",
29
+ load: "/mnt/shared_vol/checkpoints/multimodal_transformer_rn50x16",
30
+
31
+ eval_every: 100,
32
+
33
+ }
configs/MAGMA_v2.yml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ # image encoder settings
3
+ encoder_name: 'clip_resnet_large',
4
+ adapter_config: {"mlp": {"adapter_type": "normal", "downsample_factor": 8}, "attention": {"adapter_type": "normal", "downsample_factor": 8}},
5
+ freeze_img_encoder: false,
6
+
7
+ # train settings
8
+ batch_size: 256,
9
+ train_steps: 150000,
10
+ lr: 8.0e-4,
11
+ min_lr: 0.0,
12
+ lr_decay_iters: 300000,
13
+ image_enc_lr: 2.0e-6,
14
+ use_image_embed_layernorm: true,
15
+ image_embed_dropout_prob: 0.1,
16
+ image_size: 384,
17
+
18
+ gradient_accumulation_steps: 4,
19
+ zero_stage: 2,
20
+ gradient_clipping: 1.0,
21
+
22
+ # dataset / save / load settings
23
+ dataset_type: 'new',
24
+ train_dataset_dir: ['/mnt/localdisk/laion', '/mnt/brick/CC3M_converted', '/mnt/localdisk/localized_narratives', '/mnt/localdisk/visual_genome_converted', '/mnt/localdisk/hateful_memes_converted', '/mnt/localdisk/coco_converted', '/mnt/brick/wit_converted', '/mnt/localdisk/gqa_train_converted', '/mnt/localdisk/vqa_train_converted', '/mnt/localdisk/okvqa_train_converted'], #'/mnt/brick/wit_converted'
25
+
26
+ eval_dataset_dir: null, # if this is none, train dataset will be split
27
+ vqa_dir: "/mnt/localdisk/vqa_val_converted",
28
+ gqa_dir: "/mnt/localdisk/gqa_val_converted",
29
+
30
+ save: "/mnt/shared_vol/checkpoints/MAGMA_RN50x16",
31
+ load: "/mnt/shared_vol/checkpoints/MAGMA_RN50x16",
32
+
33
+ eval_every: 250,
34
+ wandb_project: "MAGMA_training",
35
+ name: "MAGMA_RN50x16_v1"
36
+ }
example_inference.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from magma import Magma
2
+ from magma.image_input import ImageInput
3
+
4
+ model = Magma.from_checkpoint(
5
+ config_path = "configs/MAGMA_v1.yml",
6
+ checkpoint_path = "./mp_rank_00_model_states.pt",
7
+ device = 'cuda:0'
8
+ )
9
+
10
+ inputs =[
11
+ ## supports urls and path/to/image
12
+ ImageInput('https://www.art-prints-on-demand.com/kunst/thomas_cole/woods_hi.jpg'),
13
+ 'Describe the painting:'
14
+ ]
15
+
16
+ ## returns a tensor of shape: (1, 149, 4096)
17
+ embeddings = model.preprocess_inputs(inputs)
18
+
19
+ ## returns a list of length embeddings.shape[0] (batch size)
20
+ output = model.generate(
21
+ embeddings = embeddings,
22
+ max_steps = 6,
23
+ temperature = 0.7,
24
+ top_k = 0,
25
+ )
26
+
27
+ print(output[0]) ## A cabin on a lake
examples/magma_oracle.png ADDED
examples/magma_present.jpg ADDED
examples/magma_social.png ADDED
examples/magma_treasure.png ADDED
examples/magma_tree.jpg ADDED
examples/model.jpg ADDED
magma/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .config import MultimodalConfig
2
+ from .magma import Magma
3
+ from .language_model import get_gptj
4
+ from .transforms import get_transforms
5
+ from .utils import (
6
+ count_parameters,
7
+ is_main,
8
+ cycle,
9
+ get_tokenizer,
10
+ parse_args,
11
+ wandb_log,
12
+ wandb_init,
13
+ save_model,
14
+ load_model,
15
+ print_main,
16
+ configure_param_groups,
17
+ log_table,
18
+ )
19
+ from .train_loop import eval_step, inference_step, train_step
20
+ from .datasets import collate_fn
magma/adapters.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchtyping import TensorType
4
+
5
+
6
+ class Adapter(nn.Module):
7
+ def __init__(
8
+ self,
9
+ dim: int,
10
+ downsample_factor: int = 4,
11
+ activation: nn.Module = nn.ReLU,
12
+ add_layernorm: bool = False,
13
+ ):
14
+ super().__init__()
15
+ layers = []
16
+ if add_layernorm:
17
+ layers.append(nn.LayerNorm(dim))
18
+ layers.extend(
19
+ [
20
+ nn.Linear(dim, dim // downsample_factor),
21
+ activation(),
22
+ nn.Linear(dim // downsample_factor, dim),
23
+ ]
24
+ )
25
+ self.adapter = nn.Sequential(*layers)
26
+ self.adapter.apply(self.init_weights)
27
+
28
+ def init_weights(self, m: nn.Module, std=1e-3):
29
+ if isinstance(m, nn.Linear):
30
+ torch.nn.init.normal_(m.weight, std=std)
31
+ torch.nn.init.normal_(m.bias, std=std)
32
+ m.weight.data = torch.clamp(m.weight.data, min=-2 * std, max=2 * std)
33
+ m.bias.data = torch.clamp(m.bias.data, min=-2 * std, max=2 * std)
34
+ elif isinstance(m, nn.LayerNorm):
35
+ m.bias.data.zero_()
36
+ m.weight.data.fill_(1.0)
37
+
38
+ def forward(self, x: TensorType["b", "s", "d"]) -> TensorType["b", "s", "d"]:
39
+ return self.adapter(x) + x
40
+
41
+
42
+ class ParallelAdapter(Adapter):
43
+ def __init__(
44
+ self,
45
+ module: nn.Module,
46
+ dim: int,
47
+ downsample_factor: int = 4,
48
+ scaled: bool = False,
49
+ add_layernorm: bool = False,
50
+ activation: nn.Module = nn.ReLU,
51
+ ):
52
+ super().__init__(
53
+ dim, downsample_factor, add_layernorm=add_layernorm, activation=activation
54
+ )
55
+ self.module = module
56
+
57
+ if scaled:
58
+ # init scaling param
59
+ self.adapter_scale = nn.Parameter(torch.ones(1))
60
+ else:
61
+ self.adapter_scale = 1
62
+
63
+ def forward(self, x: TensorType["b", "s", "d"], **module_kwargs):
64
+ y = self.module(x, **module_kwargs)
65
+ z = self.adapter(x)
66
+ return y + (z * self.adapter_scale)
67
+
68
+
69
+ class ParallelAdapterWrapper(ParallelAdapter):
70
+ # used to add an adapter to the attention block
71
+
72
+ def __init__(
73
+ self,
74
+ module: nn.Module,
75
+ dim: int,
76
+ downsample_factor: int = 4,
77
+ scaled: bool = False,
78
+ add_layernorm: bool = False,
79
+ activation: nn.Module = nn.ReLU,
80
+ ):
81
+ super().__init__(
82
+ module, dim, downsample_factor, scaled, add_layernorm, activation
83
+ )
84
+
85
+ def forward(self, x: TensorType["b", "s", "d"], *attn_args, **attn_kwargs):
86
+ attn_outputs = self.module(x, *attn_args, **attn_kwargs)
87
+ attn_output, outputs = (
88
+ attn_outputs[0],
89
+ attn_outputs[1:],
90
+ ) # output_attn: a, present, (attentions)
91
+ hidden_states = attn_output + (self.adapter(x) * self.adapter_scale)
92
+ return (hidden_states,) + outputs
93
+
94
+
95
+ class AdapterWrapper(Adapter):
96
+ # used to add an adapter to the attention block
97
+
98
+ def __init__(
99
+ self,
100
+ attn_block: nn.Module,
101
+ dim: int,
102
+ downsample_factor: int = 4,
103
+ activation: nn.Module = nn.ReLU,
104
+ add_layernorm: bool = False,
105
+ ):
106
+ super().__init__(dim, downsample_factor, activation, add_layernorm)
107
+ self.attn_block = attn_block
108
+
109
+ def forward(self, x: TensorType["b", "s", "d"], *attn_args, **attn_kwargs):
110
+ attn_outputs = self.attn_block(x, *attn_args, **attn_kwargs)
111
+ attn_output, outputs = (
112
+ attn_outputs[0],
113
+ attn_outputs[1:],
114
+ ) # output_attn: a, present, (attentions)
115
+ hidden_states = self.adapter(attn_output) + attn_output
116
+ return (hidden_states,) + outputs
magma/config.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, asdict
2
+ import yaml
3
+ from pprint import pprint
4
+ from .utils import is_main
5
+ import os
6
+ from pathlib import Path
7
+ import uuid
8
+
9
+
10
+ def load_config(path, config_dir=Path("configs")):
11
+ if not path.endswith(".yml"):
12
+ path += ".yml"
13
+ if not os.path.exists(path):
14
+ path = config_dir / path
15
+ with open(path, "r") as stream:
16
+ config = yaml.safe_load(stream)
17
+ return config
18
+
19
+
20
+ @dataclass
21
+ class MultimodalConfig:
22
+
23
+ # Training:
24
+ # ------------------------------------------------------------
25
+
26
+ batch_size: int
27
+ train_steps: int
28
+ optimizer_name: str = "AdamW"
29
+ lr: float = 8.0e-4
30
+ image_enc_lr: float = None
31
+ min_lr: float = 0.0
32
+ lr_decay_iters: int = None
33
+ gradient_accumulation_steps: int = 1
34
+ image_size: int = 256
35
+ eval_every: int = 250
36
+ eval_steps: int = 25
37
+ zero_stage: int = 2
38
+ gradient_clipping: float = 1.0
39
+ warmup_num_steps: int = 100
40
+ weight_decay: float = 0.00
41
+ run_blind: bool = False
42
+ fine_tune: bool = False
43
+ load_optimizer: bool = True
44
+
45
+ # Checkpointing:
46
+ # ------------------------------------------------------------
47
+ save_every: int = 2500
48
+ save: str = None
49
+ load: str = None
50
+
51
+ # Data:
52
+ # ------------------------------------------------------------
53
+ train_dataset_name: str = "conceptual_captions"
54
+ eval_dataset_name: str = "/data/conceptual_captions"
55
+ train_dataset_dir: str = "/data/coco_data"
56
+ eval_dataset_dir: str = "/data/coco_data"
57
+ eval_dataset_pct: float = 0.1
58
+
59
+ # Model architecture:
60
+ # ------------------------------------------------------------
61
+ encoder_name: str = "clip"
62
+ tokenizer_name: str = "gpt2"
63
+ lm_name: str = "EleutherAI/gpt-j-6B"
64
+ image_seq_len: int = 2
65
+ pretrained_img_encoder: bool = False
66
+ seq_len: int = None
67
+
68
+ # Layer Freezing settings:
69
+ # ------------------------------------------------------------
70
+ freeze_lm: bool = True
71
+ freeze_img_encoder: bool = True
72
+
73
+ image_embed_dropout_prob: float = 0.0
74
+ use_image_embed_layernorm: bool = False
75
+
76
+ # Adapter settings:
77
+ # ------------------------------------------------------------
78
+ adapter_config: dict = None
79
+
80
+ # Classification Finetuning settings:
81
+ # ------------------------------------------------------------
82
+ class_dict: dict = None # {num_classes: .., ckpt_path: .., classifier_type:, .., interface_type: .., interface_position: .., freeze_model: ..}
83
+
84
+ # Logging settings:
85
+ # ------------------------------------------------------------
86
+ name: str = None # name, just used for wandb logging
87
+ log_every: int = 1
88
+ wandb_project: str = "magma"
89
+
90
+ def print(self):
91
+ if is_main():
92
+ print("-" * 100)
93
+ pprint(self.__dict__, indent=4)
94
+ print("-" * 100)
95
+
96
+ def __post_init__(self):
97
+ self.is_classifier = self.class_dict is not None
98
+ if self.adapter_config is None:
99
+ self.adapter_config = {}
100
+
101
+ # Deepspeed Settings:
102
+ # ------------------------------------------------------------
103
+ if self.lr_decay_iters is None:
104
+ self.lr_scheduler = "WarmupLR"
105
+ self.scheduler_dict = {
106
+ "type": self.lr_scheduler,
107
+ "params": {
108
+ "warmup_min_lr": self.min_lr,
109
+ "warmup_max_lr": self.lr,
110
+ "warmup_num_steps": self.warmup_num_steps,
111
+ },
112
+ }
113
+ else:
114
+ self.lr_scheduler = "WarmupDecayLR"
115
+ self.scheduler_dict = {
116
+ "type": self.lr_scheduler,
117
+ "params": {
118
+ "total_num_steps": self.lr_decay_iters,
119
+ "warmup_min_lr": self.min_lr,
120
+ "warmup_max_lr": self.lr,
121
+ "warmup_num_steps": self.warmup_num_steps,
122
+ },
123
+ }
124
+ self.deepspeed_config_params = {
125
+ "train_batch_size": self.batch_size,
126
+ "gradient_accumulation_steps": self.gradient_accumulation_steps,
127
+ "gradient_clipping": self.gradient_clipping,
128
+ "fp16": {"enabled": True, "loss_scale_window": 250},
129
+ "scheduler": self.scheduler_dict,
130
+ "zero_optimization": {
131
+ "stage": self.zero_stage,
132
+ "load_from_fp32_weights": False,
133
+ },
134
+ }
135
+
136
+ if self.name is None:
137
+ self.name = str(uuid.uuid4())[:8]
138
+
139
+ @classmethod
140
+ def from_yml(cls, path):
141
+ return cls(**load_config(path))
142
+
143
+ def to_dict(self):
144
+ return asdict(self)
magma/datasets/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .dataset import (
2
+ ImgCptDataset,
3
+ collate_fn,
4
+ )
5
+
magma/datasets/convert_datasets.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from PIL import UnidentifiedImageError
3
+ import os
4
+ import json
5
+ from pathlib import Path
6
+ from tqdm import tqdm
7
+ import shutil
8
+
9
+
10
+ def save_to_jsons(data_list, target_dir, starting_idx=0):
11
+ pbar = tqdm(
12
+ enumerate(data_list), desc=f"saving {len(data_list)} jsons to {str(target_dir)}"
13
+ )
14
+ for k, data in pbar:
15
+ filename = Path(target_dir) / Path(f"{k+starting_idx}.json")
16
+ with open(filename, "w") as f:
17
+ json.dump(data, f)
18
+
19
+ return None
20
+
21
+
22
+ def save_images(img_list, target_dir, mode="mv"):
23
+ for img_path in tqdm(
24
+ img_list,
25
+ desc=f"saving {len(img_list)} images (mode={mode}) to {str(target_dir)}",
26
+ ):
27
+ if mode == "mv":
28
+ shutil.move(img_path, target_dir)
29
+ elif mode == "cp":
30
+ shutil.copy(img_path, target_dir)
31
+
32
+
33
+ def convert_dataset(
34
+ data_dir,
35
+ dir_size=10000,
36
+ hash_fn=None,
37
+ mode="mv",
38
+ ds_iterator=None,
39
+ ):
40
+ """
41
+ Builds a dataset directory in our standard format. ds_iterator should return data of the form
42
+ image_path, {"captions": [...], "metadata": {...}, }, where image_path should be a Path object, captions should map to a list of strings
43
+ and metadata can contain any custom data about the image. If a hash_fn is specified (such as phash), the image hash gets saved in metadata.
44
+ """
45
+
46
+ data_dir = Path(data_dir)
47
+
48
+ # folders for images and corresponding data which is stored in a json file for each image
49
+ os.makedirs(data_dir / "images", exist_ok=True)
50
+ os.makedirs(data_dir / "image_data", exist_ok=True)
51
+
52
+ img_data_list = []
53
+ img_path_list = []
54
+ save_img_dir = data_dir / "images" / "0"
55
+ save_data_dir = data_dir / "image_data" / "0"
56
+ num_img_dirs = 0
57
+
58
+ # save the new locations of all img files in case some datafiles point to the same image
59
+ new_img_locations = {}
60
+
61
+ pbar = tqdm(
62
+ enumerate(ds_iterator),
63
+ desc="converting dataset to standard format...",
64
+ )
65
+
66
+ for k, (img_path, data) in pbar:
67
+ img_cpt_data = {}
68
+ # get img data
69
+ img_cpt_data.update(data)
70
+
71
+ if str(img_path) in new_img_locations.keys():
72
+ # if filename is in the dictionary, it already has a new location
73
+ new_img_path = new_img_locations[str(img_path)]["new_img_path"]
74
+ img_cpt_data["image_path"] = new_img_path
75
+ if hash_fn is not None:
76
+ img_cpt_data["metadata"]["image_hash"] = new_img_locations[
77
+ str(img_path)
78
+ ]["hash"]
79
+ else:
80
+ # if file exists in the old location, it will get moved to a new directory
81
+ new_img_path = f"images/{save_img_dir.name}/{img_path.name}"
82
+ img_cpt_data["image_path"] = new_img_path
83
+ new_img_locations[str(img_path)] = {"new_img_path": new_img_path}
84
+ # original location is saved an later saved to the new directory
85
+ img_path_list.append(img_path)
86
+
87
+ # if given, apply hash fn
88
+ if hash_fn is not None:
89
+ try:
90
+ img = Image.open(img_path).convert("RGB")
91
+ hash_str = str(hash_fn(img))
92
+ img_cpt_data["metadata"]["image_hash"] = hash_str
93
+ # save hash so it does not have to be recomputed
94
+ new_img_locations[str(img_path)]["hash"] = hash_str
95
+ except (UnidentifiedImageError, FileNotFoundError):
96
+ print("Warning: corrupted or non-existent Image")
97
+
98
+ img_data_list.append(img_cpt_data)
99
+
100
+ # save images in specified images folder (maximum of dir_size images per folder)
101
+ if (len(img_path_list) % dir_size == 0 and len(img_path_list) > 0) or (
102
+ k == len(ds_iterator) - 1
103
+ ):
104
+ os.makedirs(save_img_dir, exist_ok=True)
105
+ save_images(img_path_list, save_img_dir, mode=mode)
106
+ img_path_list = []
107
+ num_img_dirs += 1
108
+ save_img_dir = data_dir / "images" / f"{num_img_dirs}/"
109
+
110
+ # save jdon data in specified image_data folder with consecutive labeling of the json files
111
+ if ((k + 1) % dir_size == 0) or (k == len(ds_iterator) - 1):
112
+ os.makedirs(save_data_dir, exist_ok=True)
113
+ save_to_jsons(
114
+ img_data_list, save_data_dir, starting_idx=max(k + 1 - dir_size, 0)
115
+ )
116
+ # empty path and data lists and update save directories for next saving step
117
+ img_data_list = []
118
+ save_data_dir = data_dir / "image_data" / f"{int((k+1)/dir_size)}/"
magma/datasets/dataset.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ from PIL import Image
4
+ from PIL.Image import Image as img
5
+ from PIL.Image import DecompressionBombError
6
+ from PIL import UnidentifiedImageError
7
+ import json
8
+ from pathlib import Path
9
+
10
+ from tqdm import tqdm
11
+ from typing import List, Tuple, Generator
12
+ import random
13
+ from multiprocessing import Pool, cpu_count
14
+
15
+ from PIL import Image
16
+ from torch.utils.data import Dataset
17
+ from typing import Tuple
18
+ from torchtyping import TensorType
19
+ import traceback
20
+
21
+
22
+ def read_jsonl(filename: str) -> Generator[List, None, None]:
23
+ """
24
+ Iterator over data from a jsonl file
25
+ """
26
+ with open(filename) as file:
27
+ for line in file:
28
+ yield json.loads(line.rstrip("\n|\r"))
29
+
30
+
31
+ def read_img_captions(filename: str) -> List[Tuple[str, str]]:
32
+ """
33
+ Yields image_path, image_caption from cc jsonl files
34
+ """
35
+ img_captions = []
36
+ for item in read_jsonl(filename):
37
+ if not "N/A" in item[-2:]:
38
+ img_captions.append((item[-1], item[-2]))
39
+ return img_captions
40
+
41
+
42
+ def load_json(filename):
43
+ try:
44
+ with open(filename) as f:
45
+ return json.load(f)
46
+ except Exception:
47
+ print(f"ERROR: Error loading json file {filename}")
48
+ traceback.print_exc()
49
+
50
+
51
+ def _read_image_data(data_dir):
52
+ image_data = []
53
+ img_data_dir = data_dir / "image_data"
54
+ paths = _load_paths(data_dir)
55
+ pbar = tqdm(
56
+ paths,
57
+ desc=f"loading dataset from {str(data_dir)}",
58
+ )
59
+ # read data with multiprocessing
60
+ with Pool(cpu_count()) as pool:
61
+ for img_data in pool.imap(load_json, pbar):
62
+ if img_data is not None:
63
+ image_data.append(img_data)
64
+ return image_data
65
+
66
+
67
+ def _load_paths(data_dir, sort=True):
68
+ paths = []
69
+ img_data_dir = data_dir / "image_data"
70
+ for p in tqdm(
71
+ Path(img_data_dir).glob("*/*.json"),
72
+ desc=f"loading dataset paths from {str(data_dir)}",
73
+ ):
74
+ paths.append(p)
75
+ return sorted(paths)
76
+
77
+
78
+ class LazyLoader:
79
+ def __init__(self, data_dir):
80
+ self.paths = _load_paths(data_dir)
81
+
82
+ def __len__(self):
83
+ return len(self.paths)
84
+
85
+ def __getitem__(self, idx):
86
+ data = load_json(self.paths[idx])
87
+ if data is None:
88
+ return self[random.randint(0, len(self) - 1)]
89
+ return data
90
+
91
+
92
+ class ImgCptDataset(Dataset):
93
+ """
94
+ Dataset which loads image caption data from our standard format and transforms them into tensors that can be input to the model.
95
+ Images are expected to be stored in data_dir/images, image data in data_dir/image_data and each data item is a json file with format {"image_path": img_path, "captions": [caption1, caption2,...], "metadata":{...}}
96
+ """
97
+
98
+ def __init__(
99
+ self, data_dir, tokenizer, transforms, seq_len=2048, load_data_in_memory=False
100
+ ):
101
+ self.data_dir = Path(data_dir)
102
+ self.tokenizer = tokenizer
103
+ self.transforms = transforms
104
+ self.seq_len = seq_len
105
+ self.load_data_in_memory = load_data_in_memory
106
+ if self.load_data_in_memory:
107
+ self.data = _read_image_data(self.data_dir)
108
+ else:
109
+ self.data = LazyLoader(self.data_dir)
110
+
111
+ def __len__(self):
112
+ return len(self.data)
113
+
114
+ def __getitem__(
115
+ self, idx
116
+ ) -> Tuple[TensorType["b", "c", "h", "w"], TensorType["b", "s"]]:
117
+ img_data = self.data[idx]
118
+ try:
119
+ try:
120
+ img_path = self.data_dir / img_data["image_path"]
121
+ except KeyError as e:
122
+ # if no image path is found, assume path is same as .json, but .jpg
123
+ if not self.load_data_in_memory:
124
+ p = self.data.paths[idx]
125
+ img_path = (
126
+ self.data_dir
127
+ / "images"
128
+ / Path(p.parent).name
129
+ / Path(p.name).with_suffix(".jpg")
130
+ )
131
+ else:
132
+ raise e
133
+ img = Image.open(img_path)
134
+ img_tensor = self.transforms(img)
135
+ caption = random.choice(img_data["captions"])
136
+ caption_tensor = self.tokenizer.encode(
137
+ caption,
138
+ return_tensors="pt",
139
+ max_length=self.seq_len,
140
+ padding="max_length",
141
+ truncation=True,
142
+ )
143
+ return img_tensor, caption_tensor
144
+ except (
145
+ UnidentifiedImageError,
146
+ OSError,
147
+ DecompressionBombError,
148
+ IndexError,
149
+ ) as e:
150
+ # return random index if image is corrupt
151
+ print(f"Warning: Could not load image {str(img_path)}")
152
+ return self[random.randint(0, len(self) - 1)]
153
+
154
+
155
+ def collate_fn(batch_data: List[Tuple[torch.Tensor, torch.Tensor]], seq_len=2048):
156
+
157
+ all_images, all_captions = list(
158
+ zip(*batch_data)
159
+ ) # [(img1, caption1), (img2, caption2), ... ] -> [(img1, img2, ... ), (caption1, caption2, ... )]
160
+ return torch.cat(all_images), torch.cat([i[:, :seq_len] for i in all_captions])
magma/image_encoders.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Callable, Union
4
+ from torchtyping import patch_typeguard
5
+ from einops import rearrange
6
+ import timm
7
+ import clip
8
+ from functools import partial
9
+
10
+ # ----------------------------- Utils --------------------------------------
11
+
12
+ clip.model.LayerNorm = (
13
+ nn.LayerNorm
14
+ ) # we need to patch this for clip to work with deepspeed
15
+ patch_typeguard() # needed for torchtyping typechecks to work
16
+
17
+
18
+ class Lambda(torch.nn.Module):
19
+ def __init__(self, fn: Callable):
20
+ super().__init__()
21
+ assert hasattr(fn, "__call__")
22
+ self.fn = fn
23
+
24
+ def forward(self, x):
25
+ return self.fn(x)
26
+
27
+
28
+ # ------------------------- Image encoders ----------------------------------
29
+
30
+
31
+ def nfresnet50(
32
+ device: Union[torch.device, str] = None, pretrained: bool = True
33
+ ) -> nn.Module:
34
+ """
35
+ Loads nfresnet50 model, removing the pooling layer and replacing it with
36
+ an adaptive pooling layer.
37
+ """
38
+ encoder = torch.nn.Sequential(
39
+ *list(timm.create_model("nf_resnet50", pretrained=pretrained).children())[:-1]
40
+ )
41
+ pooling = torch.nn.AdaptiveAvgPool2d((1, 1))
42
+ encoder = torch.nn.Sequential(encoder, pooling)
43
+ if device is not None:
44
+ encoder = encoder.to(device)
45
+ return encoder
46
+
47
+
48
+ def clip_encoder(
49
+ device: Union[torch.device, str] = None, name: str = "clip",
50
+ ) -> nn.Module:
51
+ """
52
+ Loads clip's image encoder module, discarding the lm component.
53
+
54
+ If the variant is a resnet model, we also remove the attention pooling.
55
+ """
56
+ if name in ["clip", "ViT-B/32"]:
57
+ name = "ViT-B/32"
58
+ elif name in ["clip_resnet", "RN50x4"]:
59
+ name = "RN50x4"
60
+ elif name in ["clip_resnet_large", "RN50x16"]:
61
+ name = "RN50x16"
62
+ else:
63
+ raise ValueError(f"encoder {name} not recognized")
64
+
65
+ encoder = clip.load(name, device=device)[0].visual
66
+
67
+ if device is not None:
68
+ encoder = encoder.to(device)
69
+
70
+ if "RN" in name:
71
+ # remove attention pooling
72
+ encoder.attnpool = Lambda(
73
+ partial(rearrange, pattern="b d h w -> b (h w) d")
74
+ ) # remove attn pooling, just use reshaped features
75
+
76
+ return encoder
77
+
78
+
79
+ def get_image_encoder(
80
+ name: str, device: Union[torch.device, str] = None, pretrained: bool = False
81
+ ) -> torch.nn.Module:
82
+ """
83
+ Loads image encoder module
84
+ """
85
+ if name == "nfresnet50":
86
+ encoder = nfresnet50(device=device, pretrained=pretrained)
87
+ elif "clip" in name:
88
+ encoder = clip_encoder(device=device, name=name)
89
+ else:
90
+ raise ValueError(f"image encoder {name} not recognized")
91
+ return encoder
magma/image_input.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from io import BytesIO
3
+ import PIL.Image as PilImage
4
+ from typing import Callable
5
+
6
+ class ImageInput():
7
+ """Wrapper to handle image inputs both from local paths and urls
8
+ Args:
9
+ path_or_url (str): path or link to image.
10
+ """
11
+ def __init__(self, path_or_url):
12
+
13
+ self.path_or_url = path_or_url
14
+ if self.path_or_url.startswith("http://") or self.path_or_url.startswith("https://"):
15
+ try:
16
+ response = requests.get(path_or_url)
17
+ self.pil_image = PilImage.open(BytesIO(response.content))
18
+ except:
19
+ raise Exception(f'Could not retrieve image from url:\n{self.path_or_url}')
20
+ else:
21
+ self.pil_image = PilImage.open(path_or_url)
22
+
23
+ def get_transformed_image(self, transform_fn: Callable): ## to be called internally
24
+ return transform_fn(self.pil_image)
magma/image_prefix.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchtyping import TensorType
4
+ from einops import rearrange
5
+ from .image_encoders import get_image_encoder
6
+ from .config import MultimodalConfig
7
+
8
+ # ------------------------- Image prefix ----------------------------------
9
+
10
+ # for models that are fixed to a specific sequence lengths (i.e clip models with no pooling), the sequence lengths are below
11
+ ENCODER_SEQ_LENS = {
12
+ "clip_resnet": 49,
13
+ "clip_resnet_large": 144,
14
+ }
15
+
16
+ ENCODER_OUT_DIMS = {
17
+ "nfresnet50": 2048,
18
+ "clip": 512,
19
+ "clip_resnet": 2560,
20
+ "clip_resnet_large": 3072,
21
+ }
22
+
23
+
24
+ class ImagePrefix(nn.Module):
25
+
26
+ """
27
+ Takes in a batch of images and returns a batch of embeddings of the
28
+ same dimensions as the LM's word embeddings.
29
+
30
+ :param config: MultimodalConfig object
31
+ :param out_dim: output dimension of the embedding
32
+ :param device: device to run the model on
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ config: MultimodalConfig,
38
+ out_dim: int = 2048,
39
+ device=None,
40
+ ):
41
+ super().__init__()
42
+ self.device = device or torch.device(
43
+ "cuda" if torch.cuda.is_available() else "cpu"
44
+ )
45
+ self.config = config
46
+ self.encoder_type = config.encoder_name
47
+
48
+ # get image encoder backbone
49
+ self.enc = get_image_encoder(
50
+ config.encoder_name,
51
+ pretrained=config.pretrained_img_encoder,
52
+ )
53
+ self.encoder_out_dim = ENCODER_OUT_DIMS[
54
+ self.encoder_type
55
+ ] # out dim for image encoder
56
+
57
+ self.out_dim = out_dim # out dim for lm
58
+
59
+ # set the out seq len to that specified in the config, or for some models, the hardcoded value
60
+ self.out_seq_len = (
61
+ config.image_seq_len
62
+ if config.encoder_name not in ENCODER_SEQ_LENS
63
+ else ENCODER_SEQ_LENS[config.encoder_name]
64
+ )
65
+
66
+ # get the output projection
67
+ proj_out_dim = (
68
+ (self.out_dim * self.out_seq_len)
69
+ if self.encoder_type not in ENCODER_SEQ_LENS
70
+ else self.out_dim
71
+ )
72
+ self.proj = nn.Linear(self.encoder_out_dim, proj_out_dim)
73
+ self.dropout = nn.Dropout(config.image_embed_dropout_prob)
74
+ self.use_layernorm = config.use_image_embed_layernorm
75
+ if self.use_layernorm:
76
+ self.ln = nn.LayerNorm(self.out_dim)
77
+
78
+ def forward(
79
+ self, x: TensorType["b", "c", "h", "w"]
80
+ ) -> TensorType["b", "seq", "out_dim"]:
81
+
82
+ # pass through image encoder
83
+ logits = self.enc(x)
84
+
85
+ # remove trailing dimensions of size 1 + pass through linear
86
+ if logits.ndim == 4:
87
+ logits = rearrange(logits, "b d 1 1 -> b d")
88
+ elif logits.ndim == 3:
89
+ assert self.encoder_type in ENCODER_SEQ_LENS
90
+ else:
91
+ assert logits.ndim == 2
92
+
93
+ logits = self.proj(logits)
94
+
95
+ # reshape to desired output shape
96
+ if (
97
+ self.encoder_type not in ENCODER_SEQ_LENS
98
+ ): # don't need to reshape those with fixed seq lens / no pooling
99
+ logits = rearrange(
100
+ logits, "b (s d) -> b s d", d=self.out_dim, s=self.out_seq_len
101
+ )
102
+
103
+ # pass through dropout and layer norm
104
+ logits = self.dropout(logits)
105
+
106
+ if self.use_layernorm:
107
+ logits = self.ln(logits)
108
+
109
+ return logits
magma/language_model.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import GPTNeoForCausalLM, AutoConfig, GPT2LMHeadModel
3
+ from .utils import print_main
4
+ from pathlib import Path
5
+ from transformers.modeling_utils import no_init_weights
6
+
7
+ LANGUAGE_MODELS = [
8
+ "gptj",
9
+ ]
10
+
11
+
12
+ def gptj_config():
13
+ config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")
14
+ config.attention_layers = ["global"] * 28
15
+ config.attention_types = [["global"], 28]
16
+ config.num_layers = 28
17
+ config.num_heads = 16
18
+ config.hidden_size = 256 * config.num_heads
19
+ config.vocab_size = 50400
20
+ config.rotary = True
21
+ config.rotary_dim = 64
22
+ config.jax = True
23
+ config.gradient_checkpointing = True
24
+ return config
25
+
26
+
27
+ def get_gptj(
28
+ gradient_checkpointing: bool = True,
29
+ from_pretrained=False,
30
+ ) -> torch.nn.Module:
31
+ """
32
+ Loads GPTJ language model from HF
33
+ """
34
+ print_main("Loading GPTJ language model...")
35
+ config = gptj_config()
36
+ config.gradient_checkpointing = gradient_checkpointing
37
+ if gradient_checkpointing:
38
+ config.use_cache = False
39
+ config.model_device = "cpu"
40
+ if from_pretrained:
41
+ raise NotImplemented("GPTJ pretrained not implemented")
42
+ else:
43
+ with no_init_weights():
44
+ model = GPTNeoForCausalLM(config=config)
45
+ return model
magma/magma.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from os.path import exists
3
+ import torch
4
+ import torch.nn as nn
5
+ from copy import deepcopy
6
+ from typing import Literal, Optional, List
7
+ from torchtyping import TensorType
8
+ from transformers.file_utils import ModelOutput
9
+ from magma.config import MultimodalConfig
10
+
11
+ from magma.utils import get_tokenizer
12
+ from .language_model import get_gptj
13
+ from .adapters import (
14
+ Adapter,
15
+ ParallelAdapter,
16
+ AdapterWrapper,
17
+ ParallelAdapterWrapper,
18
+ )
19
+ from .image_prefix import ImagePrefix
20
+ from .sampling import generate
21
+ from .utils import build_labels, is_url, print_main, download_checkpoint
22
+ from .image_input import ImageInput
23
+ from .transforms import get_transforms
24
+
25
+ # ------------------------- Magma main class ----------------------------------
26
+
27
+
28
+ class Magma(nn.Module):
29
+ def __init__(self, config, device=None):
30
+ super().__init__()
31
+
32
+ if isinstance(config, (str, Path)):
33
+ config = MultimodalConfig.from_yml(
34
+ config
35
+ ) # load config from yml file if config is a string
36
+ else:
37
+ assert isinstance(config, MultimodalConfig)
38
+
39
+ self.device = device or torch.device(
40
+ "cuda" if torch.cuda.is_available() else "cpu"
41
+ )
42
+ self.config = config
43
+ self.lm = get_gptj().to(self.device)
44
+ self.seq_len = self.lm.config.max_position_embeddings
45
+
46
+ self.tokenizer = get_tokenizer("gpt2", sequence_length=self.seq_len)
47
+
48
+ self.image_token = self.tokenizer.cls_token_id
49
+ self.eos_token = self.tokenizer.eos_token_id
50
+ self.lm.resize_token_embeddings(len(self.tokenizer))
51
+ self.lm.config.pad_token_id = self.tokenizer.eos_token_id
52
+ self.word_embedding = self.lm.transformer.wte.to(device)
53
+ self.transformer = self.lm.transformer.h
54
+
55
+ # adapter settings
56
+ self.mlp_adapter_added, self.attn_adapter_added = False, False
57
+
58
+ self.image_prefix = ImagePrefix(
59
+ config=config,
60
+ out_dim=self.lm.config.hidden_size,
61
+ ).to(self.device)
62
+
63
+ # might change based on the type of image encoder, so get from prefix instead of config
64
+ self.image_prefix_seq_len = self.image_prefix.out_seq_len
65
+
66
+ self.transforms = get_transforms(
67
+ config.image_size,
68
+ config.encoder_name,
69
+ input_resolution=self.image_prefix.enc.input_resolution,
70
+ )
71
+
72
+ # add adapters
73
+ if config.adapter_config:
74
+ mlp_config = deepcopy(config.adapter_config.get("mlp", None))
75
+ if mlp_config:
76
+ assert mlp_config.get("adapter_type") is not None
77
+ self.add_adapters(
78
+ location="mlp",
79
+ adapter_type=mlp_config.pop("adapter_type"),
80
+ downsample_factor=mlp_config.pop("downsample_factor", 4),
81
+ **mlp_config,
82
+ )
83
+ attn_config = deepcopy(config.adapter_config.get("attention", None))
84
+ if attn_config:
85
+ assert attn_config.get("adapter_type") is not None
86
+ self.add_adapters(
87
+ location="attention",
88
+ adapter_type=attn_config.pop("adapter_type"),
89
+ **attn_config,
90
+ )
91
+
92
+ # freeze parameters
93
+ if config.freeze_lm:
94
+ for name, param in self.lm.named_parameters(): # freeze lm weights
95
+ if config.adapter_config and "adapter" in name:
96
+ param.requires_grad = True
97
+
98
+ if config.freeze_img_encoder:
99
+ for param in self.image_prefix.enc.parameters():
100
+ param.requires_grad = False
101
+
102
+ def add_adapters(
103
+ self,
104
+ downsample_factor: int = 4,
105
+ adapter_type: Literal["normal", "parallel", "scaled_parallel"] = "normal",
106
+ location: Literal["mlp", "attention"] = "mlp",
107
+ ff_attr: str = "mlp",
108
+ attn_attr: str = "attn",
109
+ **adapter_kwargs,
110
+ ):
111
+ """
112
+ Adds an adapter layer to `self` at the specified location
113
+ """
114
+ assert adapter_type in [
115
+ "normal",
116
+ "parallel",
117
+ "scaled_parallel",
118
+ ], "adapter_type must be one of 'normal', 'parallel', or 'scaled_parallel'"
119
+ assert location in [
120
+ "mlp",
121
+ "attention",
122
+ ], "location must be one of 'mlp' or 'attention'"
123
+
124
+ for l in range(len(self.transformer)):
125
+ if location == "mlp":
126
+ if self.mlp_adapter_added:
127
+ raise ValueError("Adapter layer already added")
128
+ mlp = getattr(self.transformer[l], ff_attr)
129
+ if adapter_type in ["parallel", "scaled_parallel"]:
130
+ adapter_layer = ParallelAdapter(
131
+ module=mlp,
132
+ dim=self.lm.config.hidden_size,
133
+ downsample_factor=downsample_factor,
134
+ scaled=adapter_type == "scaled_parallel",
135
+ **adapter_kwargs,
136
+ )
137
+ else:
138
+ adpt = Adapter(
139
+ dim=self.lm.config.hidden_size,
140
+ downsample_factor=downsample_factor,
141
+ **adapter_kwargs,
142
+ )
143
+ adapter_layer = nn.Sequential(
144
+ *[
145
+ mlp,
146
+ adpt,
147
+ ]
148
+ )
149
+ setattr(self.transformer[l], ff_attr, adapter_layer)
150
+ else:
151
+ if self.attn_adapter_added:
152
+ raise ValueError("Adapter layer already added")
153
+ attn = getattr(self.transformer[l], attn_attr)
154
+ if adapter_type in ["parallel", "scaled_parallel"]:
155
+ adapter_layer = ParallelAdapterWrapper(
156
+ module=attn,
157
+ dim=self.lm.config.hidden_size,
158
+ downsample_factor=downsample_factor,
159
+ scaled="scaled" in adapter_type,
160
+ **adapter_kwargs,
161
+ )
162
+ else:
163
+ adapter_layer = AdapterWrapper(
164
+ attn_block=attn,
165
+ dim=self.lm.config.hidden_size,
166
+ downsample_factor=downsample_factor,
167
+ **adapter_kwargs,
168
+ )
169
+ setattr(self.transformer[l], attn_attr, adapter_layer)
170
+
171
+ if location == "mlp":
172
+ self.mlp_adapter_added = True
173
+ else:
174
+ self.attn_adapter_added = True
175
+
176
+ def preprocess_inputs(self, input_list: list, embed = True) -> List[torch.Tensor]:
177
+ """
178
+ Expects a list of strings and instances of ImageInput
179
+ Converts them into a list of tensors and then optionally runs self.embed over it
180
+ """
181
+ for i in range(len(input_list)):
182
+ inp = input_list[i]
183
+ if isinstance(inp, str):
184
+ input_list[i] = self.tokenizer.encode(inp, return_tensors="pt")
185
+ elif isinstance(inp, ImageInput):
186
+ input_list[i] = inp.get_transformed_image(transform_fn = self.transforms)
187
+ else:
188
+ raise Exception(f'Invalid input type:{type(inp)}')
189
+
190
+ if embed == True:
191
+ return self.embed(input_list)
192
+ else:
193
+ return input_list
194
+
195
+ def embed(self, inputs: List[torch.Tensor]) -> TensorType["b", "s", "d"]:
196
+ """
197
+ Embeds a list of tensors In the correct format to input into the LM (b, s, d).
198
+ For each tensor, if it's 2d assume it's text and use word embedding,
199
+ if it's 4d, assume it's an image, and use image_prefix to embed.
200
+ """
201
+ emb_list = []
202
+ for x in inputs:
203
+ if x.ndim == 2:
204
+ x = x.to(self.device)
205
+ emb_list.append(self.word_embedding(x))
206
+ elif x.ndim == 4:
207
+ x = x.to(self.device).half()
208
+ image_embeddings = self.image_prefix(x)
209
+ emb_list.append(image_embeddings)
210
+ else:
211
+ raise ValueError(f"Expected 2d or 4d tensor, got {x.ndim}d")
212
+ return torch.cat(emb_list, dim=1)
213
+
214
+ @torch.no_grad()
215
+ def generate(
216
+ self,
217
+ embeddings: TensorType["b", "s", "d"],
218
+ max_steps: int = 100,
219
+ temperature: float = 0.7,
220
+ top_k: int = 0,
221
+ top_p: float = 0.9,
222
+ decode: bool = True,
223
+ ):
224
+ """
225
+ Generates captions for a batch of embeddings.
226
+ """
227
+
228
+ return generate(
229
+ self,
230
+ embeddings=embeddings,
231
+ max_steps=max_steps,
232
+ temperature=temperature,
233
+ top_k=top_k,
234
+ top_p=top_p,
235
+ decode=decode,
236
+ )
237
+
238
+ def forward(
239
+ self,
240
+ images: TensorType["b", "c", "h", "w"] = None,
241
+ captions: Optional[TensorType["b", "seq"]] = None,
242
+ output_hidden_states: bool = False,
243
+ input_embeddings: TensorType["b", "s", "d"] = None,
244
+ ) -> ModelOutput:
245
+ assert captions is not None, "Must provide captions in training"
246
+ assert any([i is not None for i in [images, input_embeddings]]) and not all(
247
+ [i is not None for i in [images, input_embeddings]]
248
+ ), "Pass in either images, or input embeddings, not both."
249
+ assert (
250
+ captions.shape[1] == self.seq_len
251
+ ), f"in training, captions should be padded to sequence length ({self.seq_len}), but are length {captions.shape[1]}"
252
+
253
+ if input_embeddings is None:
254
+ input_embeddings = self.image_prefix(images)
255
+ labels = build_labels(
256
+ input_embeddings, captions, self.eos_token, self.device
257
+ ) # build labels from input_embeddings
258
+ word_embeddings = self.word_embedding(captions)
259
+
260
+ # join together
261
+ input_embeddings = torch.cat(
262
+ (
263
+ input_embeddings,
264
+ word_embeddings[:, : -input_embeddings.shape[1], :],
265
+ ), # remove padding in the word embedding before concatenating
266
+ dim=1,
267
+ )
268
+
269
+ # forward joined embeddings through lm
270
+ lm_outputs = self.lm(
271
+ inputs_embeds=input_embeddings,
272
+ labels=labels,
273
+ output_hidden_states=output_hidden_states,
274
+ )
275
+
276
+ return lm_outputs
277
+
278
+ @classmethod
279
+ def from_checkpoint(cls, config_path, checkpoint_path, device = 'cpu'):
280
+ """
281
+ Loads a model checkpoint from disk / downlods from url if not present
282
+ """
283
+
284
+ checkpoint_url = 'https://drive.google.com/u/0/uc?id=1EiAY3IcKWmGADaLDzdG25ykQghUwza6L&export=download'
285
+
286
+ if exists(checkpoint_path) == False:
287
+ print_main(f'checkpoint: {checkpoint_path} does not exist, downloading model')
288
+ download_checkpoint(checkpoint_url = checkpoint_url, save_as = checkpoint_path)
289
+
290
+ model = cls(config = config_path)
291
+
292
+ sd = torch.load(checkpoint_path, map_location=torch.device("cpu"))
293
+ if "module" in sd.keys():
294
+ sd = sd["module"]
295
+
296
+ print_main('loading checkpoint magma')
297
+ model.load_state_dict(sd, strict=False)
298
+ print_main("magma model successfully loaded")
299
+
300
+ model.half().to(device)
301
+ return model
magma/sampling.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torchtyping import TensorType
4
+ from typing import Union, List
5
+
6
+
7
+ def top_p_filter(logits: TensorType[..., "vocab"], threshold: float = 0.9):
8
+ """
9
+ Nucleus sampling
10
+ """
11
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
12
+ cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
13
+
14
+ sorted_indices_to_remove = cum_probs > (1 - threshold)
15
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
16
+ sorted_indices_to_remove[..., 0] = 0
17
+
18
+ sorted_logits[sorted_indices_to_remove] = float("-inf")
19
+ return sorted_logits.scatter(1, sorted_indices, sorted_logits)
20
+
21
+
22
+ def top_k_filter(logits, k):
23
+ """
24
+ Top K sampling
25
+ """
26
+ assert k > 0
27
+ val, ind = torch.topk(logits, k)
28
+ probs = torch.full_like(logits, float("-inf"))
29
+ probs.scatter_(1, ind, val)
30
+ return probs
31
+
32
+
33
+ def remove_tokens_after_eos(tensor, eos_token, image_token):
34
+ # any tokens after and end of sequence token is produced are also set to the eos token, and removed
35
+ eos_index = (tensor == eos_token).nonzero()
36
+ if eos_index.any():
37
+ tensor[eos_index[0] :] = eos_token
38
+
39
+ tensor = tensor.tolist()
40
+ return [i for i in tensor if (not i == image_token) and (not i == eos_token)]
41
+
42
+
43
+ @torch.no_grad()
44
+ def generate(
45
+ model: "Magma",
46
+ embeddings: TensorType["b", "s", "d"],
47
+ max_steps: int = 100,
48
+ temperature: float = 0.7,
49
+ top_k: int = 0,
50
+ top_p: float = 0.9,
51
+ eos_token: int = None,
52
+ decode: bool = True,
53
+ ) -> Union[List[str], TensorType["b", "s"]]:
54
+ """
55
+ Generates captions for a batch of embeddings.
56
+
57
+ :param model: The model to use for generation.
58
+ :param embeddings: The embeddings to generate captions for.
59
+ :param max_steps: The maximum number of steps to generate captions for.
60
+ :param temperature: The temperature to use for sampling.
61
+ :param top_k: value for top k sampling. If 0, no sampling will be used.
62
+ :param top_p: value for top p sampling. If 0, no sampling will be used.
63
+ :param eos_token: The token to use for end of sequence.
64
+ :param decode: Whether to decode the output into text, or return the raw tokens.
65
+ """
66
+
67
+ # init values
68
+ eos_token = eos_token or model.eos_token
69
+ was_training = model.training
70
+ model.eval()
71
+ b, s, _ = embeddings.shape
72
+ past_key_values = None
73
+
74
+ # init output with image tokens
75
+ out = torch.zeros((b, s), dtype=torch.long).to(model.device) + model.image_token
76
+
77
+ # do sampling
78
+ for i in range(max_steps):
79
+ if i == 0:
80
+ # initial input
81
+ outputs = model.lm(
82
+ inputs_embeds=embeddings,
83
+ use_cache=True,
84
+ past_key_values=past_key_values,
85
+ )
86
+ else:
87
+ # now caching past k/v so we can use only the last token
88
+ outputs = model.lm(
89
+ input_ids=out[:, -1:], use_cache=True, past_key_values=past_key_values
90
+ )
91
+
92
+ logits = outputs.logits[:, -1, :].float()
93
+ past_key_values = outputs.past_key_values
94
+
95
+ # filter / temperature sample
96
+ if temperature == 0.0:
97
+ next_token = torch.argmax(logits, dim=-1)
98
+ else:
99
+ if top_k > 0:
100
+ logits = top_k_filter(logits, k=top_k)
101
+ if top_p > 0:
102
+ logits = top_p_filter(logits, threshold=top_p)
103
+
104
+ probs = F.softmax(logits / temperature, dim=-1)
105
+ next_token = torch.multinomial(probs, num_samples=1)
106
+
107
+ out = torch.cat((out, next_token), dim=-1)
108
+
109
+ if eos_token is not None and (next_token == eos_token).all():
110
+ break
111
+
112
+ if decode:
113
+ captions = []
114
+ for b in out:
115
+ b = remove_tokens_after_eos(b, eos_token, model.image_token)
116
+ caption = model.tokenizer.decode(b)
117
+ captions.append(caption)
118
+ out = captions
119
+
120
+ model.train(was_training)
121
+ return out
magma/train_loop.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm import tqdm
3
+ from .utils import reduce_losses, to_cuda_half
4
+ from torchvision.utils import make_grid
5
+
6
+
7
+ def train_step(config, train_loader, model_engine):
8
+ losses = []
9
+
10
+ for _ in range(config.gradient_accumulation_steps):
11
+ images, captions = next(train_loader)
12
+ images, captions = images.half().cuda(), captions.cuda()
13
+ if config.run_blind:
14
+ images = torch.zeros_like(images)
15
+ outputs = model_engine(images, captions)
16
+ loss = outputs.loss
17
+ losses.append(loss)
18
+ model_engine.backward(loss)
19
+ model_engine.step()
20
+
21
+ return reduce_losses(torch.mean(torch.stack(losses))).item()
22
+
23
+
24
+ def train_step_classification(config, train_loader, model_engine, return_accuracy=True):
25
+ losses = []
26
+ if return_accuracy:
27
+ accuracies = []
28
+ for _ in range(config.gradient_accumulation_steps):
29
+ images, captions, class_labels = next(train_loader)
30
+ images, captions, class_labels = to_cuda_half(images, captions, class_labels)
31
+ if config.run_blind:
32
+ images = torch.zeros_like(images)
33
+ loss, logits = model_engine(images, captions, class_labels)
34
+ losses.append(loss)
35
+ if return_accuracy:
36
+ argmax_pred = logits.argmax(dim=-1)
37
+ accuracies.append((argmax_pred == class_labels).float().mean())
38
+ model_engine.backward(loss)
39
+ model_engine.step()
40
+
41
+ loss_reduced = reduce_losses(torch.mean(torch.stack(losses))).item()
42
+ if return_accuracy:
43
+ accuracy_reduced = reduce_losses(torch.mean(torch.stack(accuracies))).item()
44
+ return loss_reduced, accuracy_reduced
45
+ return loss_reduced
46
+
47
+
48
+ def eval_step(config, eval_loader, model_engine):
49
+ losses = []
50
+
51
+ for i in tqdm(range(config.eval_steps), "evaluating..."):
52
+ images, captions = next(eval_loader)
53
+ images, captions = images.half().cuda(), captions.cuda()
54
+ if config.run_blind:
55
+ images = torch.zeros_like(images)
56
+ outputs = model_engine(images, captions)
57
+ loss = outputs.loss
58
+ losses.append(loss)
59
+
60
+ return reduce_losses(torch.mean(torch.stack(losses))).item()
61
+
62
+
63
+ def eval_step_classification(config, train_loader, model_engine, return_accuracy=True):
64
+ losses = []
65
+ if return_accuracy:
66
+ accuracies = []
67
+ for _ in range(config.gradient_accumulation_steps):
68
+ images, captions, class_labels = next(train_loader)
69
+ images, captions, class_labels = to_cuda_half(images, captions, class_labels)
70
+ if config.run_blind:
71
+ images = torch.zeros_like(images)
72
+ loss, logits = model_engine(images, captions, class_labels)
73
+ losses.append(loss)
74
+ if return_accuracy:
75
+ argmax_pred = logits.argmax(dim=-1)
76
+ accuracies.append((argmax_pred == class_labels).float().mean())
77
+
78
+ loss_reduced = reduce_losses(torch.mean(torch.stack(losses))).item()
79
+ if return_accuracy:
80
+ accuracy_reduced = reduce_losses(torch.mean(torch.stack(accuracies))).item()
81
+ return loss_reduced, accuracy_reduced
82
+ return loss_reduced
83
+
84
+
85
+ def inference_step(config, eval_loader, model_engine):
86
+ images, _ = next(eval_loader)
87
+ images = images.half().cuda()
88
+ if config.run_blind:
89
+ images = torch.zeros_like(images)
90
+ captions = model_engine(
91
+ images, captions=None, inference=True
92
+ ) # [caption1, caption2, ... b]
93
+ width = min(2, images.shape[0])
94
+ image_grid = make_grid(images[:width])
95
+ caption = ""
96
+ for i in range(width):
97
+ caption += f"Caption {i}: \n{captions[i]}\n"
98
+ return image_grid, caption
magma/transforms.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import transforms as T
2
+ import torch.nn.functional as F
3
+ from PIL import ImageOps
4
+ import PIL
5
+ import random
6
+
7
+
8
+ def pad_to_size(x, size=256):
9
+ delta_w = size - x.size[0]
10
+ delta_h = size - x.size[1]
11
+ padding = (
12
+ delta_w // 2,
13
+ delta_h // 2,
14
+ delta_w - (delta_w // 2),
15
+ delta_h - (delta_h // 2),
16
+ )
17
+ new_im = ImageOps.expand(x, padding)
18
+ return new_im
19
+
20
+
21
+ def pad_to_size_tensor(x, size=256):
22
+ offset_dim_1 = size - x.shape[1]
23
+ offset_dim_2 = size - x.shape[2]
24
+
25
+ padding_dim_1 = max(offset_dim_1 // 2, 0)
26
+ padding_dim_2 = max(offset_dim_2 // 2, 0)
27
+
28
+ if offset_dim_1 % 2 == 0:
29
+ pad_tuple_1 = (padding_dim_1, padding_dim_1)
30
+ else:
31
+ pad_tuple_1 = (padding_dim_1 + 1, padding_dim_1)
32
+
33
+ if offset_dim_2 % 2 == 0:
34
+ pad_tuple_2 = (padding_dim_2, padding_dim_2)
35
+ else:
36
+ pad_tuple_2 = (padding_dim_2 + 1, padding_dim_2)
37
+
38
+ padded = F.pad(x, pad=(*pad_tuple_2, *pad_tuple_1, 0, 0))
39
+ return padded
40
+
41
+
42
+ class RandCropResize(object):
43
+
44
+ """
45
+ Randomly crops, then randomly resizes, then randomly crops again, an image. Mirroring the augmentations from https://arxiv.org/abs/2102.12092
46
+ """
47
+
48
+ def __init__(self, target_size):
49
+ self.target_size = target_size
50
+
51
+ def __call__(self, img):
52
+ img = pad_to_size(img, self.target_size)
53
+ d_min = min(img.size)
54
+ img = T.RandomCrop(size=d_min)(img)
55
+ t_min = min(d_min, round(9 / 8 * self.target_size))
56
+ t_max = min(d_min, round(12 / 8 * self.target_size))
57
+ t = random.randint(t_min, t_max + 1)
58
+ img = T.Resize(t)(img)
59
+ if min(img.size) < 256:
60
+ img = T.Resize(256)(img)
61
+ return T.RandomCrop(size=self.target_size)(img)
62
+
63
+
64
+ def get_transforms(
65
+ image_size, encoder_name, input_resolution=None, use_extra_transforms=False
66
+ ):
67
+ if "clip" in encoder_name:
68
+ assert input_resolution is not None
69
+ return clip_preprocess(input_resolution)
70
+
71
+ base_transforms = [
72
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
73
+ RandCropResize(image_size),
74
+ T.RandomHorizontalFlip(p=0.5),
75
+ ]
76
+ if use_extra_transforms:
77
+ extra_transforms = [T.ColorJitter(0.1, 0.1, 0.1, 0.05)]
78
+ base_transforms += extra_transforms
79
+ base_transforms += [
80
+ T.ToTensor(),
81
+ maybe_add_batch_dim,
82
+ ]
83
+ base_transforms = T.Compose(base_transforms)
84
+ return base_transforms
85
+
86
+
87
+ def maybe_add_batch_dim(t):
88
+ if t.ndim == 3:
89
+ return t.unsqueeze(0)
90
+ else:
91
+ return t
92
+
93
+
94
+ def pad_img(desired_size):
95
+ def fn(im):
96
+ old_size = im.size # old_size[0] is in (width, height) format
97
+
98
+ ratio = float(desired_size) / max(old_size)
99
+ new_size = tuple([int(x * ratio) for x in old_size])
100
+
101
+ im = im.resize(new_size, PIL.Image.ANTIALIAS)
102
+ # create a new image and paste the resized on it
103
+
104
+ new_im = PIL.Image.new("RGB", (desired_size, desired_size))
105
+ new_im.paste(
106
+ im, ((desired_size - new_size[0]) // 2, (desired_size - new_size[1]) // 2)
107
+ )
108
+
109
+ return new_im
110
+
111
+ return fn
112
+
113
+
114
+ def crop_or_pad(n_px, pad=False):
115
+ if pad:
116
+ return pad_img(n_px)
117
+ else:
118
+ return T.CenterCrop(n_px)
119
+
120
+
121
+ def clip_preprocess(n_px, use_pad=False):
122
+ return T.Compose(
123
+ [
124
+ T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC),
125
+ crop_or_pad(n_px, pad=use_pad),
126
+ lambda image: image.convert("RGB"),
127
+ T.ToTensor(),
128
+ maybe_add_batch_dim,
129
+ T.Normalize(
130
+ (0.48145466, 0.4578275, 0.40821073),
131
+ (0.26862954, 0.26130258, 0.27577711),
132
+ ),
133
+ ]
134
+ )
magma/utils.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch.distributed as dist
3
+ from transformers import GPT2TokenizerFast
4
+ import deepspeed
5
+ from pathlib import Path
6
+ import wandb
7
+ import os
8
+ import yaml
9
+ import torch
10
+ from collections import defaultdict
11
+ from torchtyping import TensorType
12
+ import gdown
13
+
14
+
15
+ def is_main():
16
+ if dist.is_initialized():
17
+ return dist.get_rank() == 0
18
+ return True
19
+
20
+
21
+ def print_main(*msg):
22
+ if is_main():
23
+ print(*msg)
24
+
25
+
26
+ def reduce_losses(losses):
27
+ """Reduce a tensor of losses across all GPUs."""
28
+ if dist.is_initialized():
29
+ losses = losses.detach().clone()
30
+ # We use `all_reduce` because it is better supported than `reduce`
31
+ dist.all_reduce(losses, dist.ReduceOp.SUM)
32
+ return losses / dist.get_world_size()
33
+ else:
34
+ return losses
35
+
36
+
37
+ def cycle(loader):
38
+ while True:
39
+ for data in loader:
40
+ yield data
41
+
42
+
43
+ def get_tokenizer(name="gpt2", sequence_length=2048):
44
+ """
45
+ Gets tokenizer for LM
46
+ """
47
+ if name == "gpt2":
48
+ tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
49
+ tokenizer.pad_token_id = tokenizer.eos_token
50
+ tokenizer.padding_side = "right"
51
+ tokenizer.model_max_length = sequence_length
52
+ # setup lm settings
53
+ tokenizer.add_special_tokens(
54
+ {"cls_token": "<|image|>"}
55
+ ) # add special image token to tokenizer
56
+ else:
57
+ raise ValueError(f"Tokenizer {name} not recognized")
58
+ return tokenizer
59
+
60
+
61
+ def parse_args():
62
+ parser = argparse.ArgumentParser()
63
+ parser.add_argument(
64
+ "--config", type=str, required=False, help="path to your training config"
65
+ )
66
+ parser.add_argument(
67
+ "--local_rank",
68
+ type=int,
69
+ default=-1,
70
+ help="local rank passed from distributed launcher",
71
+ )
72
+ deepspeed.add_config_arguments(parser)
73
+
74
+ args = parser.parse_args()
75
+ args.deepspeed = True
76
+ return args
77
+
78
+
79
+ def wandb_log(*args, **kwargs):
80
+ if is_main():
81
+ wandb.log(*args, **kwargs)
82
+
83
+
84
+ def wandb_init(*args, **kwargs):
85
+ if is_main():
86
+ wandb.init(*args, **kwargs)
87
+
88
+
89
+ def save_model(model_engine, save_dir, global_step, config=None):
90
+ os.makedirs(save_dir, exist_ok=True)
91
+ if config is not None:
92
+ config = config.to_dict()
93
+ with open(str(Path(save_dir) / "config.yml"), "w") as f:
94
+ yaml.dump(config, f, default_flow_style=False)
95
+ sd = {"global_step": global_step, "config": config}
96
+ model_engine.save_checkpoint(save_dir, client_state=sd)
97
+
98
+
99
+ def load_model(
100
+ model_engine, load_dir, load_optimizer_states=True, load_lr_scheduler_states=True
101
+ ):
102
+ """
103
+ Loads a model from disk and returns the global step to resume from if loading was successful, otherwise returns 0
104
+ """
105
+ try:
106
+ load_path, sd = model_engine.load_checkpoint(
107
+ load_dir,
108
+ load_optimizer_states=load_optimizer_states,
109
+ load_lr_scheduler_states=load_lr_scheduler_states,
110
+ )
111
+ except AssertionError as e:
112
+ load_path = None
113
+ print(e)
114
+ if load_path is None:
115
+ print("Model loading failed - starting from global step 0")
116
+ return 0
117
+ return sd["global_step"]
118
+
119
+
120
+ def get_params_for_weight_decay_optimization(module, config):
121
+ """
122
+ Divide params into with-weight-decay and without-weight-decay groups.
123
+ Layernorms and biases will have no weight decay but the rest will.
124
+ """
125
+ weight_decay_params = {"params": []}
126
+ no_weight_decay_params = {"params": [], "weight_decay": 0.0}
127
+ blacklist_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
128
+
129
+ for module_ in module.modules():
130
+ if isinstance(module_, blacklist_modules) or (
131
+ config.weight_decay == 0.0
132
+ ): # also include all parameters here if no weight decay is being done
133
+ no_weight_decay_params["params"].extend(
134
+ [
135
+ p
136
+ for p in list(module_._parameters.values())
137
+ if (p is not None) and p.requires_grad
138
+ ]
139
+ )
140
+ else:
141
+ for n, p in list(module_._parameters.items()):
142
+ if p is not None and p.requires_grad:
143
+ if n != "bias":
144
+ weight_decay_params["params"].append(p)
145
+ else:
146
+ no_weight_decay_params["params"].append(p)
147
+
148
+ param_dict = {
149
+ pn: p
150
+ for pn, p in module.named_parameters()
151
+ if p is not None and p.requires_grad
152
+ }
153
+ assert len(no_weight_decay_params["params"]) + len(
154
+ weight_decay_params["params"]
155
+ ) == len(
156
+ param_dict.keys()
157
+ ), "Number of params in both groups != total number of trainable params"
158
+ if config.weight_decay == 0.0:
159
+ # only return a single param group if no weight decay is being used anyway
160
+ return [no_weight_decay_params]
161
+ return [weight_decay_params, no_weight_decay_params]
162
+
163
+
164
+ def configure_param_groups(model, config):
165
+ """
166
+ Configures the different parameter groups in the model for training.
167
+ If a separate learning rate for the image prefix is provided, we separate out the groups here.
168
+ Additionally, parameters to which weight decay shouldn't be applied (layernorms / biases) are separated.
169
+ """
170
+ if config.image_enc_lr is not None:
171
+
172
+ # get the params for the image prefix / proj
173
+ image_enc_params = get_params_for_weight_decay_optimization(
174
+ model.image_prefix.enc, config
175
+ )
176
+ for pdict in image_enc_params:
177
+ pdict["lr"] = config.image_enc_lr
178
+ image_proj_params = get_params_for_weight_decay_optimization(
179
+ model.image_prefix.proj, config
180
+ )
181
+
182
+ # get the params for layernorm if it exists
183
+ if config.use_image_embed_layernorm:
184
+ image_ln_params = get_params_for_weight_decay_optimization(
185
+ model.image_prefix.ln, config
186
+ )
187
+ image_proj_params += image_ln_params
188
+
189
+ # get the params for the lm
190
+ lm_params = get_params_for_weight_decay_optimization(model.lm, config)
191
+
192
+ # get params for class head if it exists
193
+ class_params = []
194
+ if hasattr(model, "class_head") and model.class_head is not None:
195
+ class_params = get_params_for_weight_decay_optimization(
196
+ model.class_head, config
197
+ )
198
+
199
+ all_params = []
200
+ for p in image_enc_params + lm_params + image_proj_params + class_params:
201
+ if p["params"]:
202
+ all_params.append(p)
203
+ else:
204
+ all_params = get_params_for_weight_decay_optimization(model, config)
205
+
206
+ # merge param dicts with shared lr / wd values
207
+ d = defaultdict(dict)
208
+ for param_group in all_params:
209
+ lr = param_group.get("lr", None)
210
+ wd = param_group.get("weight_decay", None)
211
+ key = f"lr_{lr}_wd_{wd}"
212
+ if d[key].get("params") is None:
213
+ d[key]["params"] = []
214
+ d[key]["params"].extend(param_group["params"])
215
+ if lr is not None:
216
+ d[key]["lr"] = lr
217
+ if wd is not None:
218
+ d[key]["weight_decay"] = wd
219
+ all_params = list(d.values())
220
+
221
+ n_params = sum([len(d["params"]) for d in all_params])
222
+ param_dict = {
223
+ pn: p for pn, p in model.named_parameters() if p is not None and p.requires_grad
224
+ }
225
+ assert n_params == len(
226
+ param_dict
227
+ ), f"Some parameters are missing from param groups ({n_params} | {len(param_dict)})"
228
+
229
+ # if we're using multiple param groups, set the min / max lr for each one[]
230
+ # appropriately in deepspeed's scheduler
231
+ config.deepspeed_config_params["scheduler"]["params"]["warmup_min_lr"] = [
232
+ config.min_lr for _ in all_params
233
+ ]
234
+ config.deepspeed_config_params["scheduler"]["params"]["warmup_max_lr"] = [
235
+ d.get("lr", config.lr) for d in all_params
236
+ ]
237
+
238
+ return all_params
239
+
240
+
241
+ def count_parameters(model):
242
+ """
243
+ Counts the number of trainable parameters in a model
244
+ """
245
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
246
+
247
+
248
+ def log_table(name, model_outputs, gt_answers_list, global_step):
249
+ results_table = wandb.Table(columns=["model output", "ground truth(s)"])
250
+ for o, gt in zip(model_outputs, gt_answers_list):
251
+ results_table.add_data(o, gt)
252
+ wandb_log({f"eval/{name}": results_table}, step=global_step)
253
+
254
+
255
+ def get_world_info():
256
+ local_rank = int(os.environ["LOCAL_RANK"])
257
+ rank = int(os.environ["RANK"])
258
+ world_size = int(os.environ["WORLD_SIZE"])
259
+ return local_rank, rank, world_size
260
+
261
+
262
+ def init_distributed(backend="nccl"):
263
+ if not torch.distributed.is_initialized():
264
+ deepspeed.init_distributed(
265
+ dist_backend=backend, verbose=True, auto_mpi_discovery=True
266
+ )
267
+ local_rank, rank, world_size = get_world_info()
268
+ torch.cuda.set_device(local_rank)
269
+ return local_rank, rank, world_size
270
+
271
+
272
+ def collate_fn_classification(batch_data, seq_len=2048):
273
+
274
+ # for nvlr2: list(zip*(batch_data)) = [l_images, r_images, captions, class_labels]
275
+ image_list = list(zip(*batch_data))[:-2]
276
+ captions, class_labels = list(zip(*batch_data))[-2:]
277
+
278
+ # images, captions, class_labels = list(zip(*batch_data))
279
+ images_list = [torch.cat(image) for image in image_list]
280
+ captions = torch.cat([i[:, :seq_len] for i in captions])
281
+ class_labels = torch.stack(class_labels)
282
+ return images_list, captions, class_labels
283
+
284
+
285
+ def infer_checkpoint_path_from_config(config):
286
+ checkpoint_folder = config.save
287
+ if checkpoint_folder is None:
288
+ raise ValueError(
289
+ "No checkpoint folder specified in config. Please provide a checkpoint."
290
+ )
291
+
292
+ # check for 'latest' tag in checkpoint folder
293
+ if (Path(checkpoint_folder) / "latest").exists():
294
+ latest_ckpt = (Path(checkpoint_folder) / "latest").read_text().strip()
295
+ else:
296
+ raise ValueError(
297
+ f"No checkpoint found in {checkpoint_folder}. Please provide a checkpoint."
298
+ )
299
+
300
+ checkpoint_path = str(
301
+ Path(checkpoint_folder) / latest_ckpt / "mp_rank_00_model_states.pt"
302
+ )
303
+ if not Path(checkpoint_path).exists():
304
+ raise ValueError(
305
+ f"No checkpoint found in {checkpoint_path}. Please provide a checkpoint."
306
+ )
307
+
308
+ return checkpoint_path
309
+
310
+
311
+ # [tensor_1, tensor_2], tensor_3, tensor_4 = to_cuda_half([tensor_1, tensor_2], tensor_3, tensor_4)
312
+ # probably not working yet
313
+ def to_cuda_half(*args):
314
+ cuda_half_args = []
315
+ for x in args:
316
+ if isinstance(x, list):
317
+ x_cuda_half = to_cuda_half(*x)
318
+ cuda_half_args.append(x_cuda_half)
319
+ elif isinstance(x, tuple):
320
+ x_cuda_half = to_cuda_half(*x)
321
+ cuda_half_args.append(x_cuda_half)
322
+ else:
323
+ if x.dtype in [torch.float32, torch.float16]:
324
+ cuda_half_args.append(x.cuda().half())
325
+ elif x.dtype == torch.long:
326
+ cuda_half_args.append(x.cuda())
327
+
328
+ if len(cuda_half_args) == 1:
329
+ return cuda_half_args[0]
330
+ else:
331
+ return cuda_half_args
332
+
333
+
334
+ def build_labels(
335
+ input_embeddings: TensorType["b", "s", "d"],
336
+ captions: TensorType["b", "s"],
337
+ eos_token,
338
+ device,
339
+ ) -> TensorType["b", "s"]:
340
+ """
341
+ Builds labels from input embeddings.
342
+
343
+ Masks out the labels with -100 in positions up to the seq length of the embeddings, so loss is only computed for captions,
344
+ and not for image tokens.
345
+ Additionally, masks out everything *after* the first eos token.
346
+ """
347
+ shape = input_embeddings.shape[:2] # b, s
348
+
349
+ assert captions.shape[1] >= shape[1]
350
+
351
+ # make sure to add masked embedding tokens in the appropriate locations in the labels
352
+ embedding_tokens = torch.zeros(shape, dtype=torch.int64).to(device) - 100
353
+ labels = torch.cat(
354
+ (embedding_tokens, captions[:, : -shape[1]]), dim=1
355
+ ) # we truncate the sequence length of the captions, as they are always padded to the full sequence length
356
+
357
+ # mask out repeating eos tokens
358
+ for label in labels:
359
+ for k, token in enumerate(label):
360
+ if token == eos_token:
361
+ label[k + 1 :] = -100
362
+ break
363
+
364
+ return labels
365
+
366
+
367
+ def is_url(string):
368
+ return string.startswith("http://") or string.startswith("https://")
369
+
370
+ def download_checkpoint(checkpoint_url, save_as):
371
+
372
+ gdown.download(url = checkpoint_url, output = save_as, quiet=False)
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torchtyping
2
+ typeguard
3
+ git+https://github.com/finetuneanon/transformers.git#egg=transformers
4
+ gdown
5
+ tqdm
6
+ timm
7
+ git+https://github.com/openai/CLIP.git
8
+ deepspeed
9
+ wandb
test.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from magma import Magma
4
+ from magma.language_model import get_language_model
5
+ from magma.utils import get_tokenizer
6
+
7
+ if __name__ == "__main__":
8
+ # model = Magma.from_checkpoint(
9
+ # "configs/MAGMA_v1.yml",
10
+ # "/mnt/localdisk/mp_rank_00_model_states.pt",
11
+ # model_dir="/mnt/localdisk/gptj",
12
+ # lm_from_pretrained=True,
13
+ # )
14
+ # gptj_model = model.lm
15
+ # model.half().cuda().eval()
16
+ tokenizer = get_tokenizer()
17
+ input_text = tokenizer.encode("this is a test", return_tensors="pt").cuda()
18
+ input_img = torch.ones(1, 3, 384, 384).half().cuda()
19
+
20
+ # input = model.embed([input_img, input_text])
21
+ # logits = gptj_model(inputs_embeds=input).logits
22
+ # logits = logits.detach().cpu().numpy()
23
+ # np.save("/mnt/localdisk/logits_new.npy", logits)
24
+
25
+ from transformers import GPTJForCausalLM
26
+ import torch
27
+
28
+ # load new model
29
+ model = GPTJForCausalLM.from_pretrained(
30
+ "EleutherAI/gpt-j-6B",
31
+ revision="float16",
32
+ torch_dtype=torch.float16,
33
+ low_cpu_mem_usage=True,
34
+ )
35
+ model.cuda()
36
+
37
+ model.eval()
38
+
39
+ logits = model(input_text).logits
40
+ logits = logits.detach().cpu().numpy()
41
+ np.save("/mnt/localdisk/gptj_logits_new.npy", logits)
42
+
43
+ print("test")
train.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import deepspeed
4
+ import wandb
5
+ from torch.utils.data import random_split, ConcatDataset
6
+ from torch.optim import AdamW
7
+ from tqdm import tqdm
8
+ from functools import partial
9
+ from magma.datasets import (
10
+ collate_fn,
11
+ ImgCptDataset,
12
+ )
13
+ from magma.magma import (
14
+ Magma,
15
+ )
16
+ from magma.utils import (
17
+ is_main,
18
+ cycle,
19
+ parse_args,
20
+ wandb_log,
21
+ wandb_init,
22
+ save_model,
23
+ load_model,
24
+ print_main,
25
+ configure_param_groups,
26
+ )
27
+ from magma.train_loop import (
28
+ eval_step,
29
+ inference_step,
30
+ train_step,
31
+ )
32
+
33
+
34
+ def _load_img_cpt_datasets(dataset_dir, tokenizer, transforms):
35
+ if isinstance(dataset_dir, (list, tuple)):
36
+ return ConcatDataset(
37
+ [_load_img_cpt_datasets(d, tokenizer, transforms) for d in dataset_dir]
38
+ )
39
+ elif isinstance(dataset_dir, str):
40
+ return ImgCptDataset(dataset_dir, tokenizer=tokenizer, transforms=transforms)
41
+ else:
42
+ raise TypeError("dataset dir wrong type")
43
+
44
+
45
+ def get_pretraining_datasets(config, tokenizer, transforms):
46
+ # if config.train_dataset_dir is a list, load all datasets + join together
47
+ train_dataset = _load_img_cpt_datasets(
48
+ config.train_dataset_dir, tokenizer, transforms
49
+ )
50
+ # if no dedicated eval sets are given, use a percentage of the train dataset
51
+ if config.eval_dataset_dir is None:
52
+ eval_len = int(len(train_dataset) * config.eval_dataset_pct)
53
+ train_len = len(train_dataset) - eval_len
54
+ print(
55
+ f"Randomly splitting train_dataset into two datasets of length {train_len} and {eval_len}"
56
+ )
57
+ train_dataset, eval_dataset = random_split(train_dataset, [train_len, eval_len])
58
+ else:
59
+ eval_dataset = _load_img_cpt_datasets(
60
+ config.eval_dataset_dir, tokenizer, transforms
61
+ )
62
+
63
+ print_main(f"Loaded train dataset with {len(train_dataset)} samples")
64
+ print_main(f"Loaded eval dataset with {len(eval_dataset)} samples")
65
+
66
+ return train_dataset, eval_dataset
67
+
68
+
69
+ # tell tokenizers not to do parallelism
70
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
71
+
72
+ if __name__ == "__main__":
73
+
74
+ # parse command line arguments:
75
+ args = parse_args()
76
+ deepspeed.init_distributed()
77
+
78
+ # load model + tokenizer:
79
+ model = Magma(
80
+ args.config
81
+ ) # for finetuning one might want to load the model via Magma.from_checkpoint(...) here
82
+ tokenizer, config, transforms = model.tokenizer, model.config, model.transforms
83
+
84
+ # filter frozen from trainable parameters:
85
+ trainable_parameters = configure_param_groups(model, config)
86
+
87
+ # load data:
88
+ train_dataset, eval_dataset = get_pretraining_datasets(
89
+ config, tokenizer, transforms
90
+ )
91
+
92
+ print_main(f"Loaded train dataset with {len(train_dataset)} samples")
93
+ print_main(f"Loaded eval dataset with {len(eval_dataset)} samples")
94
+
95
+ opt = AdamW(
96
+ trainable_parameters,
97
+ config.lr,
98
+ betas=(0.9, 0.95),
99
+ weight_decay=config.weight_decay,
100
+ )
101
+
102
+ model_engine, opt, train_loader, lr_scheduler = deepspeed.initialize(
103
+ args=args,
104
+ model=model,
105
+ optimizer=opt,
106
+ model_parameters=trainable_parameters,
107
+ training_data=train_dataset,
108
+ collate_fn=partial(collate_fn, seq_len=model.seq_len),
109
+ config_params=config.deepspeed_config_params,
110
+ )
111
+ eval_loader = cycle(model_engine.deepspeed_io(eval_dataset))
112
+ train_loader = cycle(train_loader)
113
+
114
+ # initialize training
115
+ global_step = 0
116
+ if config.load:
117
+ # loads a deepspeed checkpoint if provided. For finetuning, set load_optimizer to false
118
+ previous_global_step = load_model(
119
+ model_engine,
120
+ config.load,
121
+ load_optimizer_states=config.load_optimizer,
122
+ load_lr_scheduler_states=config.load_optimizer,
123
+ )
124
+
125
+ if config.load_optimizer:
126
+ global_step = previous_global_step
127
+
128
+ pbar = tqdm(
129
+ range(0, config.train_steps),
130
+ desc="training...",
131
+ initial=global_step,
132
+ total=config.train_steps,
133
+ disable=not is_main(),
134
+ )
135
+ wandb_init(
136
+ project=config.wandb_project,
137
+ name=config.name or wandb.util.generate_id(),
138
+ config=config,
139
+ )
140
+
141
+ # training loop
142
+ for i in pbar:
143
+ if global_step >= config.train_steps:
144
+ break
145
+
146
+ ##### train step
147
+ loss = train_step(config, train_loader, model_engine)
148
+
149
+ global_step += 1
150
+
151
+ if global_step % config.log_every == 0:
152
+ pbar.set_description(f"training... Step: {global_step} Loss: {loss}")
153
+ current_lr = (
154
+ [lr for lr in lr_scheduler.get_lr()]
155
+ if lr_scheduler is not None
156
+ else config.lr
157
+ )
158
+ to_log = {"train/loss": loss, "train/lr": current_lr}
159
+ wandb_log(to_log, step=global_step)
160
+
161
+ ##### Evaluation phase
162
+ if global_step % config.eval_every == 0:
163
+ model_engine.eval()
164
+ with torch.no_grad():
165
+
166
+ ##### eval step:
167
+ eval_loss = eval_step(config, eval_loader, model_engine)
168
+
169
+ wandb_log({"eval/loss": eval_loss}, step=global_step)
170
+ pbar.set_description(
171
+ f"evaluating... Step: {global_step} Eval Loss: {eval_loss}"
172
+ )
173
+
174
+ ##### inference:
175
+ image_grid, caption = inference_step(config, eval_loader, model_engine)
176
+ wandb_log(
177
+ {"inference/image": wandb.Image(image_grid, caption=caption)},
178
+ step=global_step,
179
+ )
180
+
181
+ model_engine.train()
182
+
183
+ ##### Save model
184
+ if global_step % config.save_every == 0:
185
+ if config.save is not None:
186
+ save_model(model_engine, config.save, global_step)
187
+ print_main(f"saving model at step {global_step}")
188
+
189
+ ##### Save model after training is finished
190
+ if config.save is not None:
191
+ save_model(model_engine, config.save, global_step)
192
+ print_main(f"saving model at end of training (step {global_step})")