chrisc36 committed on
Commit
3f119eb
·
verified ·
1 Parent(s): 91f94e3

Add files using upload-large-folder tool

Browse files
Files changed (4) hide show
  1. config.json +0 -8
  2. config_molmoe.py +48 -291
  3. example.py +55 -0
  4. modeling_molmoe.py +14 -94
config.json CHANGED
@@ -95,14 +95,6 @@
95
  "rope_theta": 10000.0,
96
  "scale_logits": false,
97
  "system_prompt_kind": "demo_or_style",
98
- "tokenizer": {
99
- "identifier": "allenai/gpt-neox-olmo-dolma-v1_5",
100
- "olmo_bos_token_id": null,
101
- "olmo_eos_token_id": null,
102
- "tokenizer_adds_space": false,
103
- "tokenizer_dir": null,
104
- "truncate_direction": "right"
105
- },
106
  "transformers_version": "4.45.0.dev0",
107
  "unconditioned": false,
108
  "use_cache": true,
 
95
  "rope_theta": 10000.0,
96
  "scale_logits": false,
97
  "system_prompt_kind": "demo_or_style",
 
 
 
 
 
 
 
 
98
  "transformers_version": "4.45.0.dev0",
99
  "unconditioned": false,
100
  "use_cache": true,
config_molmoe.py CHANGED
@@ -2,7 +2,9 @@ from __future__ import annotations
2
 
3
  import logging
4
  from dataclasses import asdict, dataclass, field
 
5
  from glob import glob
 
6
  from pathlib import Path
7
  from typing import (
8
  Any,
@@ -17,168 +19,36 @@ from typing import (
17
  cast,
18
  )
19
 
20
- import torch
21
  from transformers import PretrainedConfig
22
- from omegaconf import DictConfig, ListConfig, OmegaConf
23
- from omegaconf import OmegaConf as om
24
- from omegaconf.errors import OmegaConfBaseException
25
- from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
26
- import gin
27
-
28
- #from olmo.aliases import PathOrStr
29
- from .aliases import PathOrStr
30
- #from olmo.exceptions import OLMoConfigurationError
31
- from .exceptions import OLMoConfigurationError
32
- #from olmo.util import StrEnum, resource_path
33
- from .util import StrEnum, resource_path
34
-
35
- #from olmo.mm_data.data_utils import build_tokenizer
36
- from .data_utils import build_tokenizer
37
- #from olmo.multimodal_preprocessor import MultiModalPreprocessor
38
- from .multimodal_preprocessor import MultiModalPreprocessor
39
-
40
- __all__ = [
41
- "ActivationType",
42
- "ActivationCheckpointingStrategy",
43
- "BlockType",
44
- "LayerNormType",
45
- "VisionBackboneType",
46
- "VisionBackboneConfig",
47
- "InitFnType",
48
- "ModelConfig",
49
- "OptimizerType",
50
- "OptimizerConfig",
51
- "SchedulerType",
52
- "SchedulerConfig",
53
- "DataConfig",
54
- "InstanceFilterConfig",
55
- "EvaluatorConfig",
56
- "TokenizerConfig",
57
- "TrainConfig",
58
- "PaddingDirection",
59
- "TruncationDirection",
60
- "SpeedMonitorConfig",
61
- "WandbConfig",
62
- "CompilerConfig",
63
- "WandbConfig",
64
- "FSDPPrecision",
65
- "FSDPWrapStrategy",
66
- "FSDPConfig",
67
- "CheckpointType",
68
- ]
69
 
70
  C = TypeVar("C", bound="BaseConfig")
71
  D = TypeVar("D", bound="DictConfig|ListConfig")
72
 
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  class AttentionType(StrEnum):
75
  sdpa = "sdpa"
76
  direct = "direct"
77
  flash = "flash"
78
 
79
 
80
- class BaseConfig:
81
- @classmethod
82
- def _register_resolvers(cls, validate_paths: bool = True):
83
- # Expands path globs into a list.
84
- def path_glob(*paths) -> List[str]:
85
- out = []
86
- for path in paths:
87
- matches = sorted(glob(path))
88
- if not matches and validate_paths:
89
- raise FileNotFoundError(f"{path} does not match any files or dirs")
90
- out.extend(matches)
91
- return out
92
-
93
- # Chooses the first path in the arguments that exists.
94
- def path_choose(*paths) -> str:
95
- from .util import is_url
96
-
97
- for path in paths:
98
- if is_url(path) or Path(path).exists():
99
- return path
100
- if validate_paths:
101
- raise FileNotFoundError(", ".join(paths))
102
- else:
103
- return ""
104
-
105
- # Finds the latest checkpoint in a folder.
106
- def path_last_checkpoint(path) -> str:
107
- from .util import find_latest_checkpoint
108
-
109
- latest_checkpoint = find_latest_checkpoint(path)
110
- if latest_checkpoint is None:
111
- if validate_paths:
112
- raise FileNotFoundError(f"Could not find a latest checkpoint at {path}")
113
- else:
114
- return ""
115
- else:
116
- return str(latest_checkpoint)
117
-
118
- om.register_new_resolver("path.glob", path_glob, replace=True)
119
- om.register_new_resolver("path.choose", path_choose, replace=True)
120
- om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True)
121
-
122
- @classmethod
123
- def update_legacy_settings(cls, config: D) -> D:
124
- """
125
- Update the legacy config settings whose schemas have undergone backwards-incompatible changes.
126
- """
127
- return config
128
-
129
- @classmethod
130
- def new(cls: Type[C], **kwargs) -> C:
131
- cls._register_resolvers()
132
- conf = om.structured(cls)
133
- try:
134
- if kwargs:
135
- conf = om.merge(conf, kwargs)
136
- return cast(C, om.to_object(conf))
137
- except OmegaConfBaseException as e:
138
- raise OLMoConfigurationError(str(e))
139
-
140
- @classmethod
141
- def load(
142
- cls: Type[C],
143
- path: PathOrStr,
144
- overrides: Optional[List[str]] = None,
145
- key: Optional[str] = None,
146
- validate_paths: bool = True,
147
- ) -> C:
148
- """Load from a YAML file."""
149
- cls._register_resolvers(validate_paths=validate_paths)
150
- schema = om.structured(cls)
151
- try:
152
- raw = om.load(str(path))
153
-
154
- # Backwards compatibility hack, we need this here not in `update_legacy_settings`
155
- # since it has to be applied before selecting with `key`
156
- if "tokenizer" in raw and "model" in raw:
157
- raw["model"]["tokenizer"] = raw.pop("tokenizer")
158
-
159
- if key is not None:
160
- raw = raw[key] # type: ignore
161
- raw = cls.update_legacy_settings(raw)
162
- conf = om.merge(schema, raw)
163
- if overrides:
164
- conf = om.merge(conf, om.from_dotlist(overrides))
165
- return cast(C, om.to_object(conf))
166
- except OmegaConfBaseException as e:
167
- raise OLMoConfigurationError(str(e))
168
-
169
- def save(self, path: PathOrStr) -> None:
170
- """Save to a YAML file."""
171
- om.save(config=self, f=str(path))
172
-
173
- def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]:
174
- out = asdict(self) # type: ignore
175
- if exclude is not None:
176
- for name in exclude:
177
- if name in out:
178
- del out[name]
179
- return out
180
-
181
-
182
  class LayerNormType(StrEnum):
183
  default = "default"
184
  """
@@ -290,7 +160,7 @@ class ImageProjectType(StrEnum):
290
 
291
 
292
  @dataclass
293
- class VisionBackboneConfig(BaseConfig):
294
  image_model_type: VisionBackboneType = VisionBackboneType.openai
295
  image_default_input_size: Tuple[int, int] = (336, 336)
296
  image_patch_size: int = 14
@@ -328,18 +198,7 @@ class TruncationDirection(StrEnum):
328
 
329
 
330
  @dataclass
331
- class TokenizerConfig(BaseConfig):
332
- identifier: str = "gpt2"
333
- truncate_direction: TruncationDirection = TruncationDirection.right
334
- # Does the tokenizer automatically start input text with a space
335
- tokenizer_adds_space: Optional[bool] = False
336
- tokenizer_dir: Optional[str] = None # tokenizer directory if using a seqio tokenizer
337
- olmo_bos_token_id: Optional[int] = None
338
- olmo_eos_token_id: Optional[int] = None
339
-
340
-
341
- @dataclass
342
- class ModelConfig(BaseConfig):
343
  """
344
  OLMo (model) configuration.
345
  """
@@ -429,11 +288,6 @@ class ModelConfig(BaseConfig):
429
 
430
  rope_impl: str = "cockatoo"
431
 
432
- vision_backbone: Optional[VisionBackboneConfig] = None
433
- """
434
- Vision backbone settings for multi-modal models.
435
- """
436
-
437
  vit_load_path: Optional[str] = None
438
  """
439
  Use this to load the vit model.
@@ -749,129 +603,10 @@ class ModelConfig(BaseConfig):
749
  Used for Gemma-2.
750
  """
751
 
752
- tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
753
- """
754
- Tokenizer configuration.
755
- """
756
-
757
  loss_token_weighting: Optional[str] = None
758
 
759
  gin_bindings: Optional[str] = None
760
 
761
- def get_tokenizer(self):
762
- tokenizer_cfg = self.tokenizer
763
- assert tokenizer_cfg.identifier.startswith("mm:")
764
- kargs = {}
765
- if tokenizer_cfg.identifier[3:].startswith("olmo-"):
766
- kargs["olmo_bos_token_id"] = tokenizer_cfg.olmo_bos_token_id
767
- kargs["olmo_eos_token_id"] = tokenizer_cfg.olmo_eos_token_id
768
- return build_tokenizer(
769
- tokenizer_cfg.identifier[3:],
770
- adds_space=tokenizer_cfg.tokenizer_adds_space,
771
- tokenizer_dir=tokenizer_cfg.tokenizer_dir,
772
- pad_tokenizer_to=self.vocab_size if self.pad_tokenizer else None,
773
- **kargs
774
- )
775
-
776
- def get_preprocessor(self):
777
- vision_cfg = self.vision_backbone
778
- h, w = self.llm_patches_per_crop()
779
-
780
- return MultiModalPreprocessor(
781
- loss_token_weighting=self.loss_token_weighting,
782
- always_start_with_space=self.always_start_with_space,
783
- tokenizer=self.get_tokenizer(),
784
- prompt_override=self.prompt_override,
785
- fix_image_input_idx=self.fix_image_input_idx,
786
- prompt_templates=self.prompt_type,
787
- system_prompt=self.system_prompt_kind,
788
- default_inference_len=self.default_inference_len,
789
- message_format=self.message_formatting,
790
- unconditioned=self.unconditioned,
791
- crop_mode=self.crop_mode,
792
- max_crops=self.max_crops,
793
- do_random_scale=self.do_random_scale,
794
- base_image_input_size=vision_cfg.image_default_input_size,
795
- image_patch_size=vision_cfg.image_patch_size,
796
- image_token_length_h=h,
797
- image_token_length_w=w,
798
- use_col_tokens=self.use_col_tokens,
799
- overlap_margins=self.overlap_margins,
800
- image_padding_mask=self.image_padding_embed is not None
801
- )
802
-
803
- def __post_init__(self):
804
- self.vit_layers = tuple(self.vit_layers) # type: ignore[assignment]
805
-
806
- @classmethod
807
- def update_legacy_settings(cls, config: D) -> D:
808
- """
809
- Update the legacy config settings whose schemas have undergone backwards-incompatible changes.
810
- """
811
- if "flash_attention" in config:
812
- is_flash = config.flash_attention
813
- del config.flash_attention
814
- config.attention_type = AttentionType.flash if is_flash else AttentionType.sdpa
815
-
816
- if "bos_token_id" in config:
817
- config.tokenizer.olmo_bos_token_id = config.pop("bos_token_id")
818
- config.tokenizer.olmo_eos_token_id = config.pop("eos_token_id")
819
-
820
- if "image_padding_mask" in config:
821
- assert not config["image_padding_mask"]
822
- del config["image_padding_mask"]
823
- config["image_padding_embed"] = None
824
- elif "image_padding_embed" not in config:
825
- config["image_padding_embed"] = None
826
- return config
827
-
828
- @property
829
- def effective_n_kv_heads(self) -> int:
830
- if self.n_kv_heads is None:
831
- if self.multi_query_attention is True:
832
- return 1
833
- else:
834
- return self.n_heads
835
- else:
836
- if self.multi_query_attention is None:
837
- return self.n_kv_heads
838
- if self.multi_query_attention:
839
- n_kv_heads_should_be = 1
840
- else:
841
- n_kv_heads_should_be = self.n_heads
842
- if self.n_kv_heads == n_kv_heads_should_be:
843
- return n_kv_heads_should_be
844
- else:
845
- raise OLMoConfigurationError(
846
- "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
847
- )
848
-
849
- @property
850
- def image_num_patch(self):
851
- assert self.vision_backbone is not None
852
- return self.vision_backbone.image_num_patch
853
-
854
- @property
855
- def image_patch_size(self):
856
- assert self.vision_backbone is not None
857
- return self.visoin_backbone.image_patch_size
858
-
859
- def llm_patches_per_crop(self):
860
- h, w = self.image_num_patch
861
- # Round up in case we need to pad the image features for pooling
862
- h = (h + self.image_pooling_h - 1) // self.image_pooling_h
863
- w = (w + self.image_pooling_w - 1) // self.image_pooling_w
864
- return h, w
865
-
866
- def get_max_crops(self) -> int:
867
- """Max numbers of that can be built for one image"""
868
- if self.crop_mode == "resize":
869
- return 1
870
- elif "resize" in self.crop_mode:
871
- return 1 + self.max_crops
872
- else:
873
- return self.max_crops
874
-
875
 
876
  class MolmoConfig(PretrainedConfig):
877
  model_type = "molmo"
@@ -879,7 +614,7 @@ class MolmoConfig(PretrainedConfig):
879
 
880
  def __init__(self, use_cache: bool = False, **kwargs):
881
  model_config = ModelConfig()
882
- all_kwargs = model_config.asdict()
883
  all_kwargs.update(kwargs)
884
  all_kwargs.update({"use_cache": use_cache})
885
  all_kwargs.update(
@@ -901,8 +636,8 @@ class MolmoConfig(PretrainedConfig):
901
 
902
  @property
903
  def image_num_patch(self):
904
- assert self.vision_backbone is not None
905
- return self.vision_backbone.image_num_patch
906
 
907
  @property
908
  def llm_patches_per_crop(self):
@@ -910,4 +645,26 @@ class MolmoConfig(PretrainedConfig):
910
  # Round up in case we need to pad the image features for pooling
911
  h = (h + self.image_pooling_h - 1) // self.image_pooling_h
912
  w = (w + self.image_pooling_w - 1) // self.image_pooling_w
913
- return h, w
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import logging
4
  from dataclasses import asdict, dataclass, field
5
+ from enum import Enum
6
  from glob import glob
7
+ from os import PathLike
8
  from pathlib import Path
9
  from typing import (
10
  Any,
 
19
  cast,
20
  )
21
 
 
22
  from transformers import PretrainedConfig
23
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  C = TypeVar("C", bound="BaseConfig")
26
  D = TypeVar("D", bound="DictConfig|ListConfig")
27
 
28
 
29
+ PathOrStr = Union[str, PathLike]
30
+
31
+
32
+ class StrEnum(str, Enum):
33
+ """
34
+ This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
35
+ We include this here for compatibility with older versions of Python.
36
+ """
37
+
38
+ def __str__(self) -> str:
39
+ return self.value
40
+
41
+ def __repr__(self) -> str:
42
+ return f"'{str(self)}'"
43
+
44
+
45
+
46
  class AttentionType(StrEnum):
47
  sdpa = "sdpa"
48
  direct = "direct"
49
  flash = "flash"
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  class LayerNormType(StrEnum):
53
  default = "default"
54
  """
 
160
 
161
 
162
  @dataclass
163
+ class VisionBackboneConfig:
164
  image_model_type: VisionBackboneType = VisionBackboneType.openai
165
  image_default_input_size: Tuple[int, int] = (336, 336)
166
  image_patch_size: int = 14
 
198
 
199
 
200
  @dataclass
201
+ class ModelConfig:
 
 
 
 
 
 
 
 
 
 
 
202
  """
203
  OLMo (model) configuration.
204
  """
 
288
 
289
  rope_impl: str = "cockatoo"
290
 
 
 
 
 
 
291
  vit_load_path: Optional[str] = None
292
  """
293
  Use this to load the vit model.
 
603
  Used for Gemma-2.
604
  """
605
 
 
 
 
 
 
606
  loss_token_weighting: Optional[str] = None
607
 
608
  gin_bindings: Optional[str] = None
609
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
610
 
611
  class MolmoConfig(PretrainedConfig):
612
  model_type = "molmo"
 
614
 
615
  def __init__(self, use_cache: bool = False, **kwargs):
616
  model_config = ModelConfig()
617
+ all_kwargs = asdict(model_config)
618
  all_kwargs.update(kwargs)
619
  all_kwargs.update({"use_cache": use_cache})
620
  all_kwargs.update(
 
636
 
637
  @property
638
  def image_num_patch(self):
639
+ h, w = (336, 336)
640
+ return h // 14, w // 14
641
 
642
  @property
643
  def llm_patches_per_crop(self):
 
645
  # Round up in case we need to pad the image features for pooling
646
  h = (h + self.image_pooling_h - 1) // self.image_pooling_h
647
  w = (w + self.image_pooling_w - 1) // self.image_pooling_w
648
+ return h, w
649
+
650
+ @property
651
+ def effective_n_kv_heads(self) -> int:
652
+ if self.n_kv_heads is None:
653
+ if self.multi_query_attention is True:
654
+ return 1
655
+ else:
656
+ return self.n_heads
657
+ else:
658
+ if self.multi_query_attention is None:
659
+ return self.n_kv_heads
660
+ if self.multi_query_attention:
661
+ n_kv_heads_should_be = 1
662
+ else:
663
+ n_kv_heads_should_be = self.n_heads
664
+ if self.n_kv_heads == n_kv_heads_should_be:
665
+ return n_kv_heads_should_be
666
+ else:
667
+ raise ValueError(
668
+ "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
669
+ )
670
+
example.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig
2
+ from PIL import Image
3
+ import requests
4
+
5
+
6
+ def main():
7
+ load_path = "."
8
+
9
+ # load the processor
10
+ print("Loading processor")
11
+ processor = AutoProcessor.from_pretrained(
12
+ load_path,
13
+ trust_remote_code=True,
14
+ torch_dtype='auto',
15
+ device_map='auto'
16
+ )
17
+
18
+ # load the model
19
+ print("Loading model")
20
+ model = AutoModelForCausalLM.from_pretrained(
21
+ load_path,
22
+ trust_remote_code=True,
23
+ torch_dtype='auto',
24
+ device_map='auto'
25
+ )
26
+
27
+ # process the image and text
28
+ print("Processing...")
29
+ inputs = processor.process(
30
+ images=[Image.open(requests.get("https://picsum.photos/id/237/536/354", stream=True).raw)],
31
+ text="Describe this image."
32
+ )
33
+
34
+ # move inputs to the correct device and make a batch of size 1
35
+ inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
36
+
37
+ # generate output; maximum 200 new tokens; stop generation when <|endoftext|> is generated
38
+ print("Generating....")
39
+ output = model.generate_from_batch(
40
+ inputs,
41
+ GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
42
+ tokenizer=processor.tokenizer
43
+ )
44
+
45
+ # only get generated tokens; decode them to text
46
+ generated_tokens = output[0,inputs['input_ids'].size(1):]
47
+ generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
48
+
49
+ # print the generated text
50
+ print(generated_text)
51
+
52
+
53
+
54
+ if __name__ == '__main__':
55
+ main()
modeling_molmoe.py CHANGED
@@ -27,7 +27,7 @@ from typing import (
27
  Set,
28
  Tuple,
29
  cast,
30
- Union,
31
  )
32
  from copy import deepcopy
33
  import torch
@@ -36,17 +36,10 @@ import torch.nn as nn
36
  import torch.nn.functional as F
37
  from torch import einsum
38
  import einops
39
- from transformers import PreTrainedModel
40
  from transformers.modeling_outputs import CausalLMOutputWithPast
41
 
42
- from .aliases import PathOrStr
43
- from .beam_search import (
44
- BeamSearch,
45
- Constraint,
46
- FinalSequenceScorer,
47
- Sampler
48
- )
49
- from .config import (
50
  ActivationType,
51
  BlockType,
52
  LayerNormType,
@@ -56,10 +49,10 @@ from .config import (
56
  AttentionType,
57
  )
58
 
59
- from .util import resource_path
60
  from .config_molmoe import (
61
  MolmoConfig,
62
- VisionBackboneConfig
63
  )
64
 
65
  if sys.version_info.minor > 8:
@@ -69,26 +62,14 @@ elif sys.version_info.minor == 8:
69
  else:
70
  raise SystemExit("This script supports Python 3.8 or higher")
71
 
72
- __all__ = [
73
- "LayerNormBase",
74
- "LayerNorm",
75
- "RMSLayerNorm",
76
- "RotaryEmbedding",
77
- "Activation",
78
- "GELU",
79
- "ReLU",
80
- "SwiGLU",
81
- "OLMoBlock",
82
- "OLMoSequentialBlock",
83
- "OLMo",
84
- "OLMoOutput",
85
- "OLMoGenerateOutput",
86
- ]
87
-
88
 
89
  log = logging.getLogger(__name__)
90
 
91
 
 
 
 
 
92
  def activation_checkpoint_function(cfg: ModelConfig):
93
  preserve_rng_state = not (
94
  (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and
@@ -114,20 +95,6 @@ def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: b
114
  x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
115
 
116
 
117
- def activation_checkpoint_function(cfg: MolmoConfig):
118
- preserve_rng_state = not (
119
- (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and
120
- (cfg.residual_dropout == 0.0) and (cfg.response_residual_dropout == 0.0)
121
- )
122
- from torch.utils.checkpoint import checkpoint
123
-
124
- return partial(
125
- checkpoint,
126
- preserve_rng_state=True,
127
- use_reentrant=False,
128
- )
129
-
130
-
131
  def vit_activation_checkpoint_function(cfg: MolmoConfig):
132
  v_cfg = cfg.vision_backbone
133
  preserve_rng_state = (
@@ -142,22 +109,6 @@ def vit_activation_checkpoint_function(cfg: MolmoConfig):
142
  )
143
 
144
 
145
- def should_checkpoint_block(strategy: Optional[ActivationCheckpointingStrategy], block_idx: int) -> bool:
146
- if strategy is None:
147
- return False
148
- elif (
149
- (strategy == ActivationCheckpointingStrategy.whole_layer)
150
- or (strategy == ActivationCheckpointingStrategy.one_in_two and block_idx % 2 == 0)
151
- or (strategy == ActivationCheckpointingStrategy.one_in_three and block_idx % 3 == 0)
152
- or (strategy == ActivationCheckpointingStrategy.one_in_four and block_idx % 4 == 0)
153
- or (strategy == ActivationCheckpointingStrategy.two_in_three and block_idx % 3 != 0)
154
- or (strategy == ActivationCheckpointingStrategy.three_in_four and block_idx % 4 != 0)
155
- ):
156
- return True
157
- else:
158
- return False
159
-
160
-
161
  class BufferCache(dict, MutableMapping[str, torch.Tensor]):
162
  """
163
  Cache for attention biases and other things that would normally be stored as buffers.
@@ -1557,15 +1508,11 @@ class MolmoVisionBackbone(nn.Module):
1557
  self.image_feature_dropout = Dropout(config.image_feature_dropout)
1558
 
1559
  @classmethod
1560
- def build(cls, config: MolmoConfig) -> OLMoVisionBackbone:
1561
  v_cfg = config.vision_backbone
1562
  assert v_cfg is not None
1563
  return MolmoPretrainedVisionBackbone(config)
1564
 
1565
- @abstractmethod
1566
- def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
1567
- raise NotImplementedError()
1568
-
1569
  def reset_parameters(self):
1570
  if self.image_pooling_2d is not None:
1571
  self.image_pooling_2d.reset_parameters()
@@ -1583,9 +1530,9 @@ class MolmoVisionBackbone(nn.Module):
1583
 
1584
 
1585
  class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
1586
- def __init__(self, config: MolmoVisionBackboneConfig):
1587
  super().__init__(config)
1588
- v_cfg = self.config.vision_backbone
1589
 
1590
  if v_cfg.image_model_type == VisionBackboneType.openai:
1591
  self.image_vit = VisionTransformer(config)
@@ -1640,11 +1587,6 @@ class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
1640
  if self.config.use_cls_feature:
1641
  nn.init.xavier_uniform_(self.cls_projector.weight)
1642
 
1643
- def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
1644
- self.grad_checkpointing = True
1645
- if strategy in (ActivationCheckpointingStrategy.whole_layer, ActivationCheckpointingStrategy.vit_only):
1646
- self.image_vit.set_grad_checkpointing()
1647
-
1648
  def encode_image(self, images: torch.Tensor) -> torch.Tensor:
1649
  """
1650
  : param images: (batch_size, num_crops, num_patch, n_pixels)
@@ -1802,9 +1744,6 @@ class MolmoModel(MolmoPretrainedModel):
1802
  "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
1803
  )
1804
 
1805
- self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
1806
- self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
1807
-
1808
  if not (
1809
  0 < self.config.block_group_size <= self.config.n_layers
1810
  and self.config.n_layers % self.config.block_group_size == 0
@@ -1846,25 +1785,14 @@ class MolmoModel(MolmoPretrainedModel):
1846
  ]
1847
  self.transformer.update({"blocks": nn.ModuleList(layers)})
1848
 
1849
- self.vision_backbone: Optional[OLMoVisionBackbone] = None
1850
  if config.vision_backbone is not None:
1851
  self.vision_backbone = MolmoVisionBackbone.build(config)
1852
 
1853
  if self.vision_backbone is not None:
1854
  self.vision_backbone.reset_with_pretrained_weights()
1855
 
1856
- def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
1857
- self.activation_checkpointing_strategy = strategy
1858
- if self.config.block_group_size != 1:
1859
- for block_group in self.transformer.block_groups:
1860
- block_group.set_activation_checkpointing(strategy)
1861
- else:
1862
- for block in self.transformer.blocks:
1863
- block.set_activation_checkpointing(strategy)
1864
 
1865
- if self.vision_backbone is not None:
1866
- self.vision_backbone.set_activation_checkpointing(strategy)
1867
-
1868
  @property
1869
  def device(self) -> torch.device:
1870
  device: torch.device = self.transformer.wte.weight.device # type: ignore
@@ -1873,7 +1801,6 @@ class MolmoModel(MolmoPretrainedModel):
1873
  else:
1874
  return device
1875
 
1876
-
1877
  def forward(
1878
  self,
1879
  input_ids: torch.LongTensor,
@@ -2069,14 +1996,7 @@ class MolmoModel(MolmoPretrainedModel):
2069
  all_hidden_states.append(x)
2070
 
2071
  layer_past = None if past_key_values is None else past_key_values[block_idx]
2072
- if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx):
2073
- # shape: (batch_size, seq_len, d_model)
2074
- x, cache = self._activation_checkpoint_fn(
2075
- layer, x, attention_bias=attention_bias, position_ids=position_ids, drop_mask=response_mask, layer_past=layer_past, use_cache=use_cache
2076
- )
2077
- else:
2078
- # shape: (batch_size, seq_len, d_model)
2079
- x, cache = layer(x, attention_bias=attention_bias, position_ids=position_ids, drop_mask=response_mask, layer_past=layer_past, use_cache=use_cache)
2080
 
2081
  if attn_key_values is not None:
2082
  assert cache is not None
 
27
  Set,
28
  Tuple,
29
  cast,
30
+ Union, Any,
31
  )
32
  from copy import deepcopy
33
  import torch
 
36
  import torch.nn.functional as F
37
  from torch import einsum
38
  import einops
39
+ from transformers import PreTrainedModel, GenerationConfig, Cache
40
  from transformers.modeling_outputs import CausalLMOutputWithPast
41
 
42
+ from .config_molmoe import (
 
 
 
 
 
 
 
43
  ActivationType,
44
  BlockType,
45
  LayerNormType,
 
49
  AttentionType,
50
  )
51
 
52
+
53
  from .config_molmoe import (
54
  MolmoConfig,
55
+ VisionBackboneConfig, ModelConfig
56
  )
57
 
58
  if sys.version_info.minor > 8:
 
62
  else:
63
  raise SystemExit("This script supports Python 3.8 or higher")
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  log = logging.getLogger(__name__)
67
 
68
 
69
+ class OLMoConfigurationError(Exception):
70
+ pass
71
+
72
+
73
  def activation_checkpoint_function(cfg: ModelConfig):
74
  preserve_rng_state = not (
75
  (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and
 
95
  x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
96
 
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def vit_activation_checkpoint_function(cfg: MolmoConfig):
99
  v_cfg = cfg.vision_backbone
100
  preserve_rng_state = (
 
109
  )
110
 
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  class BufferCache(dict, MutableMapping[str, torch.Tensor]):
113
  """
114
  Cache for attention biases and other things that would normally be stored as buffers.
 
1508
  self.image_feature_dropout = Dropout(config.image_feature_dropout)
1509
 
1510
  @classmethod
1511
+ def build(cls, config: MolmoConfig):
1512
  v_cfg = config.vision_backbone
1513
  assert v_cfg is not None
1514
  return MolmoPretrainedVisionBackbone(config)
1515
 
 
 
 
 
1516
  def reset_parameters(self):
1517
  if self.image_pooling_2d is not None:
1518
  self.image_pooling_2d.reset_parameters()
 
1530
 
1531
 
1532
  class MolmoPretrainedVisionBackbone(MolmoVisionBackbone):
1533
+ def __init__(self, config: MolmoConfig):
1534
  super().__init__(config)
1535
+ v_cfg = VisionBackboneConfig()
1536
 
1537
  if v_cfg.image_model_type == VisionBackboneType.openai:
1538
  self.image_vit = VisionTransformer(config)
 
1587
  if self.config.use_cls_feature:
1588
  nn.init.xavier_uniform_(self.cls_projector.weight)
1589
 
 
 
 
 
 
1590
  def encode_image(self, images: torch.Tensor) -> torch.Tensor:
1591
  """
1592
  : param images: (batch_size, num_crops, num_patch, n_pixels)
 
1744
  "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
1745
  )
1746
 
 
 
 
1747
  if not (
1748
  0 < self.config.block_group_size <= self.config.n_layers
1749
  and self.config.n_layers % self.config.block_group_size == 0
 
1785
  ]
1786
  self.transformer.update({"blocks": nn.ModuleList(layers)})
1787
 
1788
+ self.vision_backbone: Optional[MolmoVisionBackbone] = None
1789
  if config.vision_backbone is not None:
1790
  self.vision_backbone = MolmoVisionBackbone.build(config)
1791
 
1792
  if self.vision_backbone is not None:
1793
  self.vision_backbone.reset_with_pretrained_weights()
1794
 
 
 
 
 
 
 
 
 
1795
 
 
 
 
1796
  @property
1797
  def device(self) -> torch.device:
1798
  device: torch.device = self.transformer.wte.weight.device # type: ignore
 
1801
  else:
1802
  return device
1803
 
 
1804
  def forward(
1805
  self,
1806
  input_ids: torch.LongTensor,
 
1996
  all_hidden_states.append(x)
1997
 
1998
  layer_past = None if past_key_values is None else past_key_values[block_idx]
1999
+ x, cache = layer(x, attention_bias=attention_bias, position_ids=position_ids, drop_mask=response_mask, layer_past=layer_past, use_cache=use_cache)
 
 
 
 
 
 
 
2000
 
2001
  if attn_key_values is not None:
2002
  assert cache is not None