Vision-CAIR committed · verified
Commit 4e636d4 · 1 Parent(s): 2202d00

Upload folder using huggingface_hub

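An upload like this is typically produced with `upload_folder` from `huggingface_hub`. A minimal sketch for reference; the local folder path and target repo id below are illustrative, not taken from this commit:

# Minimal sketch of how a folder upload like this commit is usually made.
# The folder path and repo id are hypothetical placeholders.
from huggingface_hub import HfApi

api = HfApi()  # assumes a token via `huggingface-cli login` or HF_TOKEN
api.upload_folder(
    folder_path="./minigpt4_video",            # hypothetical local checkout
    repo_id="Vision-CAIR/MiniGPT4-video-hf",   # hypothetical target repo
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)
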
__init__.py CHANGED
@@ -9,15 +9,15 @@ import logging
 import torch
 from omegaconf import OmegaConf
 
-from minigpt4_video.registry import registry
-from minigpt4_video.base_model import BaseModel
-from minigpt4_video.base_processor import BaseProcessor
-from minigpt4_video.blip_processors import *
-from minigpt4_video.blip2 import Blip2Base
-from minigpt4_video.clip_vision_encoder import *
-from minigpt4_video.config import *
-from minigpt4_video.eva_vit import *
-from minigpt4_video.mini_gpt4_llama_v2 import MiniGPT4_Video
+from .registry import registry
+from .base_model import BaseModel
+from .base_processor import BaseProcessor
+from .blip_processors import *
+from .blip2 import Blip2Base
+from .clip_vision_encoder import *
+from .config import *
+from .eva_vit import *
+from .mini_gpt4_llama_v2 import MiniGPT4_Video
 
 
 
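The common thread in this commit is switching the package's absolute `minigpt4_video.*` imports to package-relative ones, which keeps the code importable when the uploaded folder lands on disk under a different name. A self-contained sketch of why that matters; the package and file contents below are made up for illustration:

# Demonstration: build a tiny package on disk and show that a relative import
# inside it works no matter what the package's folder is called.
import importlib
import pathlib
import sys
import tempfile

tmp = pathlib.Path(tempfile.mkdtemp())
pkg_dir = tmp / "some_arbitrary_name"          # stand-in for the uploaded folder
pkg_dir.mkdir()
(pkg_dir / "registry.py").write_text("registry = {'models': {}}\n")
(pkg_dir / "__init__.py").write_text("from .registry import registry\n")

sys.path.insert(0, str(tmp))
pkg = importlib.import_module("some_arbitrary_name")
print(pkg.registry)  # {'models': {}} -- the relative import resolved fine
# An absolute "from minigpt4_video.registry import ..." would instead require a
# top-level package literally named "minigpt4_video" to be on sys.path.
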
__pycache__/base_model.cpython-310.pyc CHANGED
Binary files a/__pycache__/base_model.cpython-310.pyc and b/__pycache__/base_model.cpython-310.pyc differ
 
__pycache__/interfaces.cpython-310.pyc ADDED
Binary file (5.29 kB).
 
__pycache__/mini_gpt4_llama_v2.cpython-310.pyc CHANGED
Binary files a/__pycache__/mini_gpt4_llama_v2.cpython-310.pyc and b/__pycache__/mini_gpt4_llama_v2.cpython-310.pyc differ
 
__pycache__/modeling_llama_v2.cpython-310.pyc CHANGED
Binary files a/__pycache__/modeling_llama_v2.cpython-310.pyc and b/__pycache__/modeling_llama_v2.cpython-310.pyc differ
 
__pycache__/modeling_mistral.cpython-310.pyc ADDED
Binary file (39.2 kB).
 
__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
 
base_model.py CHANGED
@@ -11,8 +11,8 @@ import os
 import numpy as np
 import torch
 import torch.nn as nn
-from minigpt4_video.dist_utils import download_cached_file, is_dist_avail_and_initialized
-from minigpt4_video.utils import get_abs_path, is_url
+from .dist_utils import download_cached_file, is_dist_avail_and_initialized
+from .utils import get_abs_path, is_url
 from omegaconf import OmegaConf
 
 from huggingface_hub import PyTorchModelHubMixin
blip2.py CHANGED
@@ -15,13 +15,13 @@ import torch.nn as nn
 import torch.distributed as dist
 import torch.nn.functional as F
 
-from minigpt4_video import dist_utils as dist_utils
-from minigpt4_video.dist_utils import download_cached_file
-from minigpt4_video.utils import is_url
-from minigpt4_video.logger import MetricLogger
-from minigpt4_video.base_model import BaseModel
-from minigpt4_video.Qformer import BertConfig, BertLMHeadModel
-from minigpt4_video.eva_vit import create_eva_vit_g
+import dist_utils as dist_utils
+from .dist_utils import download_cached_file
+from .utils import is_url
+from .logger import MetricLogger
+from .base_model import BaseModel
+from .Qformer import BertConfig, BertLMHeadModel
+from .eva_vit import create_eva_vit_g
 from transformers import BertTokenizer
 
 
interfaces.py ADDED
@@ -0,0 +1,190 @@
+from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
+                    Union, overload, runtime_checkable)
+
+from typing_extensions import TypeGuard
+
+from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@runtime_checkable
+class SupportsVision(Protocol):
+    """The interface required for all vision language models (VLMs)."""
+
+    supports_vision: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports vision inputs.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+
+    def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
+        ...
+
+
+# We can't use runtime_checkable with ClassVar for issubclass checks
+# so we need to treat the class as an instance and use isinstance instead
+@runtime_checkable
+class _SupportsVisionType(Protocol):
+    supports_vision: Literal[True]
+
+    def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
+        ...
+
+
+@overload
+def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
+    ...
+
+
+@overload
+def supports_vision(model: object) -> TypeGuard[SupportsVision]:
+    ...
+
+
+def supports_vision(
+    model: Union[Type[object], object],
+) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
+    if isinstance(model, type):
+        return isinstance(model, _SupportsVisionType)
+
+    return isinstance(model, SupportsVision)
+
+
+@runtime_checkable
+class SupportsLoRA(Protocol):
+    """The interface required for all models that support LoRA."""
+
+    supports_lora: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports LoRA.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+
+    packed_modules_mapping: ClassVar[Dict[str, List[str]]]
+    supported_lora_modules: ClassVar[List[str]]
+    embedding_modules: ClassVar[Dict[str, str]]
+    embedding_padding_modules: ClassVar[List[str]]
+
+    # lora_config is None when LoRA is not enabled
+    def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
+        ...
+
+
+# We can't use runtime_checkable with ClassVar for issubclass checks
+# so we need to treat the class as an instance and use isinstance instead
+@runtime_checkable
+class _SupportsLoRAType(Protocol):
+    supports_lora: Literal[True]
+
+    packed_modules_mapping: Dict[str, List[str]]
+    supported_lora_modules: List[str]
+    embedding_modules: Dict[str, str]
+    embedding_padding_modules: List[str]
+
+    def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
+        ...
+
+
+@overload
+def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
+    ...
+
+
+@overload
+def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
+    ...
+
+
+def supports_lora(
+    model: Union[Type[object], object],
+) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
+    result = _supports_lora(model)
+
+    if not result:
+        lora_attrs = (
+            "packed_modules_mapping",
+            "supported_lora_modules",
+            "embedding_modules",
+            "embedding_padding_modules",
+        )
+        missing_attrs = tuple(attr for attr in lora_attrs
+                              if not hasattr(model, attr))
+
+        if getattr(model, "supports_lora", False):
+            if missing_attrs:
+                logger.warning(
+                    "The model (%s) sets `supports_lora=True`, "
+                    "but is missing LoRA-specific attributes: %s",
+                    model,
+                    missing_attrs,
+                )
+        else:
+            if not missing_attrs:
+                logger.warning(
+                    "The model (%s) contains all LoRA-specific attributes, "
+                    "but does not set `supports_lora=True`.", model)
+
+    return result
+
+
+def _supports_lora(
+    model: Union[Type[object], object],
+) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
+    if isinstance(model, type):
+        return isinstance(model, _SupportsLoRAType)
+
+    return isinstance(model, SupportsLoRA)
+
+
+@runtime_checkable
+class HasInnerState(Protocol):
+    """The interface required for all models that has inner state."""
+
+    has_inner_state: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model has inner state.
+    Models that has inner state usually need access to the scheduler_config
+    for max_num_seqs ,etc... (Currently only used by Jamba)
+    """
+
+    def __init__(self,
+                 *,
+                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
+        ...
+
+
+@runtime_checkable
+class _HasInnerStateType(Protocol):
+    has_inner_state: ClassVar[Literal[True]]
+
+    def __init__(self,
+                 *,
+                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
+        ...
+
+
+@overload
+def has_inner_state(model: object) -> TypeGuard[HasInnerState]:
+    ...
+
+
+@overload
+def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]:
+    ...
+
+
+def has_inner_state(
+    model: Union[Type[object], object]
+) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]:
+    if isinstance(model, type):
+        return isinstance(model, _HasInnerStateType)
+
+    return isinstance(model, HasInnerState)
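
The capability helpers added in interfaces.py follow the runtime-checkable Protocol pattern: `supports_vision(...)` and friends act as TypeGuards that accept either a model class or a model instance. A small, self-contained sketch of the same pattern (class and function names here are illustrative and independent of vLLM):

# Standalone sketch of the capability-check pattern: a runtime_checkable
# Protocol plus isinstance(), which works for both classes and instances.
from typing import ClassVar, Literal, Protocol, runtime_checkable


@runtime_checkable
class _SupportsVisionType(Protocol):
    """Structural check: anything exposing `supports_vision` qualifies."""
    supports_vision: Literal[True]


class ToyVLM:
    supports_vision: ClassVar[Literal[True]] = True


def supports_vision(model) -> bool:
    # isinstance() works with data Protocols; issubclass() would raise, which
    # is why interfaces.py checks the class object "as an instance".
    return isinstance(model, _SupportsVisionType)


print(supports_vision(ToyVLM))    # True  -- the class object has the attribute
print(supports_vision(ToyVLM()))  # True  -- so does the instance
print(supports_vision(object()))  # False
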
logger.py CHANGED
@@ -13,7 +13,7 @@ from collections import defaultdict, deque
 import torch
 import torch.distributed as dist
 
-from minigpt4_video import dist_utils
+import dist_utils
 
 
 class SmoothedValue(object):
mini_gpt4_llama_v2.py CHANGED
@@ -16,9 +16,9 @@ import torch
 from torch.cuda.amp import autocast as autocast
 import torch.nn as nn
 
-from minigpt4_video.registry import registry
-from minigpt4_video.blip2 import Blip2Base, disabled_train
-from minigpt4_video.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
+from .registry import registry
+from .blip2 import Blip2Base, disabled_train
+from .conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
 from transformers import LlamaTokenizer
 from transformers import BitsAndBytesConfig
 from transformers import AutoConfig, AutoTokenizer
@@ -34,7 +34,7 @@ import numpy as np
 import os
 from transformers import PretrainedConfig
 from transformers import PreTrainedModel
-from minigpt4_video.conversation import CONV_VISION
+from .conversation import CONV_VISION
 import cv2
 def extract_audio(video_path, audio_path):
     video_clip = mp.VideoFileClip(video_path)
@@ -89,8 +89,10 @@ class MiniGPT4_Video(Blip2Base, PreTrainedModel):
     ):
         ## loop through the config minigpt4_video_config object and set the attributes
         # if isinstance(cfg, minigpt4_video_config):
-        cfg = cfg.to_dict()
-
+        try:
+            cfg = cfg.to_dict()
+        except:
+            pass
         for key, value in cfg.items():
             try:
                 setattr(self, key, value)
@@ -216,8 +218,12 @@ class MiniGPT4_Video(Blip2Base, PreTrainedModel):
         else :
             # calculate the total number of frames in the video using opencv
             total_num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-            max_images_length = 45
-            max_sub_len = 400
+            if self.model_type == "Mistral":
+                max_images_length = 90
+                max_sub_len = 800
+            else:
+                max_images_length = 45
+                max_sub_len = 400
         images = []
         frame_count = 0
         sampling_interval = int(total_num_frames / max_images_length)
@@ -839,11 +845,11 @@ class MiniGPT4_Video(Blip2Base, PreTrainedModel):
         msg = model.load_state_dict(ckpt['model'], strict=False)
         # push the model to the hub with its metadata and config file
        model.to('cuda')
-        model.push_to_hub("Vision-CAIR/MiniGPT4-video-hf")
+        # model.push_to_hub("Vision-CAIR/MiniGPT4-video-mistral-hf")
         video_config = minigpt4_video_config(cfg)
         # video_config.save_pretrained("minigpt4_video_config")
         # print("Save Minigpt-4-LLM Config: minigpt4_video_config")
-        video_config.push_to_hub("MiniGPT4-video")
+        # video_config.push_to_hub("Vision-CAIR/MiniGPT4-video-mistral-hf")
         return model
 
 
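For orientation, the new `model_type == "Mistral"` branch simply doubles the frame and subtitle budget used when sampling a video; any other model type keeps the previous 45-frame / 400-token limits. A tiny standalone sketch of the resulting sampling interval (the helper name and frame count below are made up for illustration):

# Sketch of the frame-sampling budget introduced by the Mistral branch.
def frame_sampling_interval(total_num_frames: int, model_type: str) -> int:
    max_images_length = 90 if model_type == "Mistral" else 45
    return int(total_num_frames / max_images_length)

print(frame_sampling_interval(4500, "Mistral"))  # 50  -> keep every 50th frame
print(frame_sampling_interval(4500, "Llama"))    # 100 -> keep every 100th frame
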
utils.py CHANGED
@@ -23,7 +23,7 @@ import pandas as pd
 import yaml
 from iopath.common.download import download
 from iopath.common.file_io import file_lock, g_pathmgr
-from minigpt4_video.registry import registry
+from .registry import registry
 from torch.utils.model_zoo import tqdm
 from torchvision.datasets.utils import (
     check_integrity,
@@ -422,3 +422,182 @@ def get_file_size(filename):
     """
     size_in_mb = os.path.getsize(filename) / float(1024**2)
     return size_in_mb
+
+from typing import Dict, List, Protocol, Tuple
+
+import torch
+from torch.func import functional_call
+
+from vllm.multimodal import BatchedTensors
+from vllm.utils import is_pin_memory_available
+
+
+def merge_vision_embeddings(input_ids: torch.Tensor,
+                            inputs_embeds: torch.Tensor,
+                            vision_embeddings: BatchedTensors,
+                            image_token_id: int) -> torch.Tensor:
+    """
+    Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions
+    in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`.
+
+    Note:
+        This updates `inputs_embeds` in place.
+    """
+    mask = (input_ids == image_token_id)
+    num_expected_tokens = mask.sum()
+
+    if isinstance(vision_embeddings, torch.Tensor):
+        batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
+        total_tokens = batch_size * batch_tokens
+        if num_expected_tokens != total_tokens:
+            expr = f"{batch_size} x {batch_tokens}"
+            raise ValueError(
+                f"Attempted to assign {expr} = {total_tokens} "
+                f"image tokens to {num_expected_tokens} placeholders")
+
+        inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
+    else:
+        size_per_batch = [t.shape[0] for t in vision_embeddings]
+        total_tokens = sum(size_per_batch)
+        if num_expected_tokens != total_tokens:
+            expr = ' + '.join(map(str, size_per_batch))
+            raise ValueError(
+                f"Attempted to assign {expr} = {total_tokens} "
+                f"image tokens to {num_expected_tokens} placeholders")
+
+        inputs_embeds[mask] = torch.cat(vision_embeddings)
+
+    return inputs_embeds
+
+
+class LayerFn(Protocol):
+
+    def __call__(
+        self,
+        prefix="",
+    ) -> torch.nn.Module:
+        ...
+
+
+class PPMissingLayer(torch.nn.Identity):
+    """
+    A placeholder layer for missing layers in a pipeline parallel model.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+
+_CPU_OFFLOAD_BYTES = 0
+_CPU_OFFLOAD_MAX_BYTES = 0
+
+
+def set_cpu_offload_max_bytes(max_bytes: int) -> None:
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    _CPU_OFFLOAD_BYTES = 0
+    _CPU_OFFLOAD_MAX_BYTES = max_bytes
+
+
+def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
+    device = next(module.parameters()).device
+
+    if device == torch.device("cpu"):
+        return module
+
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+        return module
+
+    pin_memory = is_pin_memory_available()
+
+    # offload parameters to CPU
+    # use pin_memory if possible, which helps cudagraph capture speed
+    for p in module.parameters():
+        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+            # we use per-parameter offloading
+            # one module might have some parameters offloaded and some not
+            break
+
+        # `torch.empty_like` does not support `pin_memory` argument
+        cpu_data = torch.empty(size=p.data.size(),
+                               dtype=p.data.dtype,
+                               layout=p.data.layout,
+                               device='cpu',
+                               pin_memory=pin_memory)
+        cpu_data.copy_(p.data)
+        p.data = cpu_data
+        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
+
+    state_dict: Dict[str, torch.Tensor] = module.state_dict()
+
+    original_forward = module.forward
+
+    def forward(*args, **kwargs):
+        module.forward = original_forward
+        device_state = {
+            # here we blindly call `to(device)`
+            # if the parameter is already on the device, it will be a no-op
+            k: v.to(device, non_blocking=True)
+            for k, v in state_dict.items()
+        }
+        output = functional_call(module,
+                                 device_state,
+                                 args=args,
+                                 kwargs=kwargs)
+        module.forward = forward
+        return output
+
+    module.forward = forward
+
+    return module
+
+
+def make_layers(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str,
+) -> Tuple[int, int, torch.nn.ModuleList]:
+    """Make a list of layers with the given layer function, taking
+    pipeline parallelism into account.
+    """
+    from vllm.distributed.parallel_state import get_pp_group
+    from vllm.distributed.utils import get_pp_indices
+    start_layer, end_layer = get_pp_indices(num_hidden_layers,
+                                            get_pp_group().rank_in_group,
+                                            get_pp_group().world_size)
+    modules = torch.nn.ModuleList(
+        [PPMissingLayer() for _ in range(start_layer)] + [
+            maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
+            for idx in range(start_layer, end_layer)
+        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
+    return start_layer, end_layer, modules
+
+
+# NOTE: don't use lru_cache here because it can prevent garbage collection
+_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}
+
+
+def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
+    """Get the names of the missing layers in a pipeline parallel model."""
+    model_id = id(model)
+    if model_id in _model_to_pp_missing_layer_names:
+        return _model_to_pp_missing_layer_names[model_id]
+
+    missing_layer_names = []
+    for name, module in model.named_modules():
+        if isinstance(module, PPMissingLayer):
+            # NOTE: the trailing dot is used to match the prefix of the layer.
+            # without the dot, we could match a layer that is not missing,
+            # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
+            missing_layer_names.append(name + '.')
+    _model_to_pp_missing_layer_names[model_id] = missing_layer_names
+
+    return missing_layer_names
+
+
+def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
+    """Check if a parameter is missing in a pipeline parallel model."""
+    for missing_layer_name in get_pp_missing_layer_names(model):
+        if name.startswith(missing_layer_name):
+            return True
+    return False
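
For orientation, the core of the newly added `merge_vision_embeddings` is a boolean-mask scatter: the embeddings at placeholder image-token positions are overwritten in place with the vision encoder's outputs after a size check. A toy, self-contained illustration of that mechanism (shapes and the placeholder id below are made up):

# Toy illustration of the placeholder-overwrite mechanism used by
# merge_vision_embeddings (tensor shapes and IMAGE_TOKEN_ID are hypothetical).
import torch

IMAGE_TOKEN_ID = 32000                        # hypothetical placeholder id
input_ids = torch.tensor([1, 32000, 32000, 5, 32000, 2])
inputs_embeds = torch.zeros(6, 4)             # (seq_len, embed_dim)
vision_embeddings = torch.ones(3, 4)          # one row per image placeholder

mask = input_ids == IMAGE_TOKEN_ID
assert int(mask.sum()) == vision_embeddings.shape[0]  # mirrors the size check
inputs_embeds[mask] = vision_embeddings               # in-place overwrite
print(inputs_embeds)  # rows 1, 2 and 4 now hold the vision embeddings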