Vision-CAIR committed: Upload folder using huggingface_hub

Changed files:
- __init__.py +9 -9
- __pycache__/base_model.cpython-310.pyc +0 -0
- __pycache__/interfaces.cpython-310.pyc +0 -0
- __pycache__/mini_gpt4_llama_v2.cpython-310.pyc +0 -0
- __pycache__/modeling_llama_v2.cpython-310.pyc +0 -0
- __pycache__/modeling_mistral.cpython-310.pyc +0 -0
- __pycache__/utils.cpython-310.pyc +0 -0
- base_model.py +2 -2
- blip2.py +7 -7
- interfaces.py +190 -0
- logger.py +1 -1
- mini_gpt4_llama_v2.py +16 -10
- utils.py +180 -1
__init__.py CHANGED
@@ -9,15 +9,15 @@ import logging
 import torch
 from omegaconf import OmegaConf
 
-from
-from
-from
-from
-from
-from
-from
-from
-from
+from .registry import registry
+from .base_model import BaseModel
+from .base_processor import BaseProcessor
+from .blip_processors import *
+from .blip2 import Blip2Base
+from .clip_vision_encoder import *
+from .config import *
+from .eva_vit import *
+from .mini_gpt4_llama_v2 import MiniGPT4_Video
 
 
 
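The net effect of this hunk is that __init__.py now imports every sibling module relatively, so the uploaded folder behaves as a self-contained package wherever it is checked out. A minimal usage sketch, assuming the folder is importable under a package name such as minigpt4_video (the package name is an assumption, not part of the diff):

# Hypothetical package name; only MiniGPT4_Video itself comes from this commit.
from minigpt4_video import MiniGPT4_Video   # re-exported by the __init__.py above

model_cls = MiniGPT4_Video                  # resolvable without any absolute package root on sys.path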
__pycache__/base_model.cpython-310.pyc CHANGED
Binary files a/__pycache__/base_model.cpython-310.pyc and b/__pycache__/base_model.cpython-310.pyc differ

__pycache__/interfaces.cpython-310.pyc ADDED
Binary file (5.29 kB)

__pycache__/mini_gpt4_llama_v2.cpython-310.pyc CHANGED
Binary files a/__pycache__/mini_gpt4_llama_v2.cpython-310.pyc and b/__pycache__/mini_gpt4_llama_v2.cpython-310.pyc differ

__pycache__/modeling_llama_v2.cpython-310.pyc CHANGED
Binary files a/__pycache__/modeling_llama_v2.cpython-310.pyc and b/__pycache__/modeling_llama_v2.cpython-310.pyc differ

__pycache__/modeling_mistral.cpython-310.pyc ADDED
Binary file (39.2 kB)

__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
base_model.py CHANGED
@@ -11,8 +11,8 @@ import os
 import numpy as np
 import torch
 import torch.nn as nn
-from
-from
+from .dist_utils import download_cached_file, is_dist_avail_and_initialized
+from .utils import get_abs_path, is_url
 from omegaconf import OmegaConf
 
 from huggingface_hub import PyTorchModelHubMixin
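For context, the unchanged PyTorchModelHubMixin import is what gives BaseModel its Hub save/load/push integration. A minimal sketch of the mixin pattern on a toy module (illustrative only, not the repo's BaseModel; exact on-disk files depend on the huggingface_hub version):

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class TinyModel(nn.Module, PyTorchModelHubMixin):
    # The mixin contributes save_pretrained / from_pretrained / push_to_hub.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

TinyModel().save_pretrained("tiny-model")          # writes weights (+ config) locally
reloaded = TinyModel.from_pretrained("tiny-model")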
blip2.py CHANGED
@@ -15,13 +15,13 @@ import torch.nn as nn
 import torch.distributed as dist
 import torch.nn.functional as F
 
-
-from
-from
-from
-from
-from
-from
+import dist_utils as dist_utils
+from .dist_utils import download_cached_file
+from .utils import is_url
+from .logger import MetricLogger
+from .base_model import BaseModel
+from .Qformer import BertConfig, BertLMHeadModel
+from .eva_vit import create_eva_vit_g
 from transformers import BertTokenizer
 
 
interfaces.py ADDED
@@ -0,0 +1,190 @@
+from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
+                    Union, overload, runtime_checkable)
+
+from typing_extensions import TypeGuard
+
+from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@runtime_checkable
+class SupportsVision(Protocol):
+    """The interface required for all vision language models (VLMs)."""
+
+    supports_vision: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports vision inputs.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+
+    def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
+        ...
+
+
+# We can't use runtime_checkable with ClassVar for issubclass checks
+# so we need to treat the class as an instance and use isinstance instead
+@runtime_checkable
+class _SupportsVisionType(Protocol):
+    supports_vision: Literal[True]
+
+    def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
+        ...
+
+
+@overload
+def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
+    ...
+
+
+@overload
+def supports_vision(model: object) -> TypeGuard[SupportsVision]:
+    ...
+
+
+def supports_vision(
+    model: Union[Type[object], object],
+) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
+    if isinstance(model, type):
+        return isinstance(model, _SupportsVisionType)
+
+    return isinstance(model, SupportsVision)
+
+
+@runtime_checkable
+class SupportsLoRA(Protocol):
+    """The interface required for all models that support LoRA."""
+
+    supports_lora: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports LoRA.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+
+    packed_modules_mapping: ClassVar[Dict[str, List[str]]]
+    supported_lora_modules: ClassVar[List[str]]
+    embedding_modules: ClassVar[Dict[str, str]]
+    embedding_padding_modules: ClassVar[List[str]]
+
+    # lora_config is None when LoRA is not enabled
+    def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
+        ...
+
+
+# We can't use runtime_checkable with ClassVar for issubclass checks
+# so we need to treat the class as an instance and use isinstance instead
+@runtime_checkable
+class _SupportsLoRAType(Protocol):
+    supports_lora: Literal[True]
+
+    packed_modules_mapping: Dict[str, List[str]]
+    supported_lora_modules: List[str]
+    embedding_modules: Dict[str, str]
+    embedding_padding_modules: List[str]
+
+    def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
+        ...
+
+
+@overload
+def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
+    ...
+
+
+@overload
+def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
+    ...
+
+
+def supports_lora(
+    model: Union[Type[object], object],
+) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
+    result = _supports_lora(model)
+
+    if not result:
+        lora_attrs = (
+            "packed_modules_mapping",
+            "supported_lora_modules",
+            "embedding_modules",
+            "embedding_padding_modules",
+        )
+        missing_attrs = tuple(attr for attr in lora_attrs
+                              if not hasattr(model, attr))
+
+        if getattr(model, "supports_lora", False):
+            if missing_attrs:
+                logger.warning(
+                    "The model (%s) sets `supports_lora=True`, "
+                    "but is missing LoRA-specific attributes: %s",
+                    model,
+                    missing_attrs,
+                )
+        else:
+            if not missing_attrs:
+                logger.warning(
+                    "The model (%s) contains all LoRA-specific attributes, "
+                    "but does not set `supports_lora=True`.", model)
+
+    return result
+
+
+def _supports_lora(
+    model: Union[Type[object], object],
+) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
+    if isinstance(model, type):
+        return isinstance(model, _SupportsLoRAType)
+
+    return isinstance(model, SupportsLoRA)
+
+
+@runtime_checkable
+class HasInnerState(Protocol):
+    """The interface required for all models that has inner state."""
+
+    has_inner_state: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model has inner state.
+    Models that has inner state usually need access to the scheduler_config
+    for max_num_seqs ,etc... (Currently only used by Jamba)
+    """
+
+    def __init__(self,
+                 *,
+                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
+        ...
+
+
+@runtime_checkable
+class _HasInnerStateType(Protocol):
+    has_inner_state: ClassVar[Literal[True]]
+
+    def __init__(self,
+                 *,
+                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
+        ...
+
+
+@overload
+def has_inner_state(model: object) -> TypeGuard[HasInnerState]:
+    ...
+
+
+@overload
+def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]:
+    ...
+
+
+def has_inner_state(
+    model: Union[Type[object], object]
+) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]:
+    if isinstance(model, type):
+        return isinstance(model, _HasInnerStateType)
+
+    return isinstance(model, HasInnerState)
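The supports_vision / supports_lora / has_inner_state helpers defined above are TypeGuard-based runtime checks: each accepts either a model class or an instance and lets type checkers narrow the type afterwards. A small usage sketch for supports_vision (ToyVLM is hypothetical, and the sketch assumes vllm is installed so interfaces.py's own imports resolve):

from interfaces import SupportsVision, supports_vision   # adjust to the package layout

class ToyVLM(SupportsVision):
    # supports_vision = True is inherited from the protocol's ClassVar.
    def __init__(self, *, multimodal_config) -> None:
        self.multimodal_config = multimodal_config

assert supports_vision(ToyVLM)                            # class-level check
assert supports_vision(ToyVLM(multimodal_config=None))    # instance-level check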
logger.py CHANGED
@@ -13,7 +13,7 @@ from collections import defaultdict, deque
 import torch
 import torch.distributed as dist
 
-
+import dist_utils
 
 
 class SmoothedValue(object):
mini_gpt4_llama_v2.py CHANGED
@@ -16,9 +16,9 @@ import torch
 from torch.cuda.amp import autocast as autocast
 import torch.nn as nn
 
-from
-from
-from
+from .registry import registry
+from .blip2 import Blip2Base, disabled_train
+from .conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
 from transformers import LlamaTokenizer
 from transformers import BitsAndBytesConfig
 from transformers import AutoConfig, AutoTokenizer
@@ -34,7 +34,7 @@ import numpy as np
 import os
 from transformers import PretrainedConfig
 from transformers import PreTrainedModel
-from
+from .conversation import CONV_VISION
 import cv2
 def extract_audio(video_path, audio_path):
     video_clip = mp.VideoFileClip(video_path)
@@ -89,8 +89,10 @@ class MiniGPT4_Video(Blip2Base, PreTrainedModel):
     ):
         ## loop through the config minigpt4_video_config object and set the attributes
         # if isinstance(cfg, minigpt4_video_config):
-
-
+        try:
+            cfg = cfg.to_dict()
+        except:
+            pass
         for key, value in cfg.items():
             try:
                 setattr(self, key, value)
@@ -216,8 +218,12 @@ class MiniGPT4_Video(Blip2Base, PreTrainedModel):
         else :
             # calculate the total number of frames in the video using opencv
            total_num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-
+            if self.model_type == "Mistral":
+                max_images_length = 90
+                max_sub_len = 800
+            else:
+                max_images_length = 45
+                max_sub_len = 400
             images = []
             frame_count = 0
             sampling_interval = int(total_num_frames / max_images_length)
@@ -839,11 +845,11 @@ class MiniGPT4_Video(Blip2Base, PreTrainedModel):
         msg = model.load_state_dict(ckpt['model'], strict=False)
         # push the model to the hub with its metadata and config file
         model.to('cuda')
-        model.push_to_hub("Vision-CAIR/MiniGPT4-video-hf")
+        # model.push_to_hub("Vision-CAIR/MiniGPT4-video-mistral-hf")
         video_config = minigpt4_video_config(cfg)
         # video_config.save_pretrained("minigpt4_video_config")
         # print("Save Minigpt-4-LLM Config: minigpt4_video_config")
-        video_config.push_to_hub("MiniGPT4-video")
+        # video_config.push_to_hub("Vision-CAIR/MiniGPT4-video-mistral-hf")
         return model
 
 
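A worked example of the frame-sampling branch added above: with a Mistral language backbone the video is capped at 90 sampled frames (max_sub_len = 800), versus 45 frames (max_sub_len = 400) otherwise. For a hypothetical 4,500-frame clip:

total_num_frames = 4500                                          # hypothetical clip length
max_images_length = 90                                           # Mistral branch; the other branch uses 45
sampling_interval = int(total_num_frames / max_images_length)    # -> 50
# i.e. one frame is kept every 50 frames, so at most 90 frames reach the model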
utils.py CHANGED
@@ -23,7 +23,7 @@ import pandas as pd
 import yaml
 from iopath.common.download import download
 from iopath.common.file_io import file_lock, g_pathmgr
-from
+from .registry import registry
 from torch.utils.model_zoo import tqdm
 from torchvision.datasets.utils import (
     check_integrity,
@@ -422,3 +422,182 @@ def get_file_size(filename):
     """
     size_in_mb = os.path.getsize(filename) / float(1024**2)
     return size_in_mb
+
+from typing import Dict, List, Protocol, Tuple
+
+import torch
+from torch.func import functional_call
+
+from vllm.multimodal import BatchedTensors
+from vllm.utils import is_pin_memory_available
+
+
+def merge_vision_embeddings(input_ids: torch.Tensor,
+                            inputs_embeds: torch.Tensor,
+                            vision_embeddings: BatchedTensors,
+                            image_token_id: int) -> torch.Tensor:
+    """
+    Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions
+    in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`.
+
+    Note:
+        This updates `inputs_embeds` in place.
+    """
+    mask = (input_ids == image_token_id)
+    num_expected_tokens = mask.sum()
+
+    if isinstance(vision_embeddings, torch.Tensor):
+        batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
+        total_tokens = batch_size * batch_tokens
+        if num_expected_tokens != total_tokens:
+            expr = f"{batch_size} x {batch_tokens}"
+            raise ValueError(
+                f"Attempted to assign {expr} = {total_tokens} "
+                f"image tokens to {num_expected_tokens} placeholders")
+
+        inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
+    else:
+        size_per_batch = [t.shape[0] for t in vision_embeddings]
+        total_tokens = sum(size_per_batch)
+        if num_expected_tokens != total_tokens:
+            expr = ' + '.join(map(str, size_per_batch))
+            raise ValueError(
+                f"Attempted to assign {expr} = {total_tokens} "
+                f"image tokens to {num_expected_tokens} placeholders")
+
+        inputs_embeds[mask] = torch.cat(vision_embeddings)
+
+    return inputs_embeds
+
+
+class LayerFn(Protocol):
+
+    def __call__(
+        self,
+        prefix="",
+    ) -> torch.nn.Module:
+        ...
+
+
+class PPMissingLayer(torch.nn.Identity):
+    """
+    A placeholder layer for missing layers in a pipeline parallel model.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+
+_CPU_OFFLOAD_BYTES = 0
+_CPU_OFFLOAD_MAX_BYTES = 0
+
+
+def set_cpu_offload_max_bytes(max_bytes: int) -> None:
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    _CPU_OFFLOAD_BYTES = 0
+    _CPU_OFFLOAD_MAX_BYTES = max_bytes
+
+
+def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
+    device = next(module.parameters()).device
+
+    if device == torch.device("cpu"):
+        return module
+
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+        return module
+
+    pin_memory = is_pin_memory_available()
+
+    # offload parameters to CPU
+    # use pin_memory if possible, which helps cudagraph capture speed
+    for p in module.parameters():
+        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+            # we use per-parameter offloading
+            # one module might have some parameters offloaded and some not
+            break
+
+        # `torch.empty_like` does not support `pin_memory` argument
+        cpu_data = torch.empty(size=p.data.size(),
+                               dtype=p.data.dtype,
+                               layout=p.data.layout,
+                               device='cpu',
+                               pin_memory=pin_memory)
+        cpu_data.copy_(p.data)
+        p.data = cpu_data
+        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
+
+    state_dict: Dict[str, torch.Tensor] = module.state_dict()
+
+    original_forward = module.forward
+
+    def forward(*args, **kwargs):
+        module.forward = original_forward
+        device_state = {
+            # here we blindly call `to(device)`
+            # if the parameter is already on the device, it will be a no-op
+            k: v.to(device, non_blocking=True)
+            for k, v in state_dict.items()
+        }
+        output = functional_call(module,
+                                 device_state,
+                                 args=args,
+                                 kwargs=kwargs)
+        module.forward = forward
+        return output
+
+    module.forward = forward
+
+    return module
+
+
+def make_layers(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str,
+) -> Tuple[int, int, torch.nn.ModuleList]:
+    """Make a list of layers with the given layer function, taking
+    pipeline parallelism into account.
+    """
+    from vllm.distributed.parallel_state import get_pp_group
+    from vllm.distributed.utils import get_pp_indices
+    start_layer, end_layer = get_pp_indices(num_hidden_layers,
+                                            get_pp_group().rank_in_group,
+                                            get_pp_group().world_size)
+    modules = torch.nn.ModuleList(
+        [PPMissingLayer() for _ in range(start_layer)] + [
+            maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
+            for idx in range(start_layer, end_layer)
+        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
+    return start_layer, end_layer, modules
+
+
+# NOTE: don't use lru_cache here because it can prevent garbage collection
+_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}
+
+
+def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
+    """Get the names of the missing layers in a pipeline parallel model."""
+    model_id = id(model)
+    if model_id in _model_to_pp_missing_layer_names:
+        return _model_to_pp_missing_layer_names[model_id]
+
+    missing_layer_names = []
+    for name, module in model.named_modules():
+        if isinstance(module, PPMissingLayer):
+            # NOTE: the trailing dot is used to match the prefix of the layer.
+            # without the dot, we could match a layer that is not missing,
+            # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
+            missing_layer_names.append(name + '.')
+    _model_to_pp_missing_layer_names[model_id] = missing_layer_names
+
+    return missing_layer_names
+
+
+def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
+    """Check if a parameter is missing in a pipeline parallel model."""
+    for missing_layer_name in get_pp_missing_layer_names(model):
+        if name.startswith(missing_layer_name):
+            return True
+    return False
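A toy usage sketch for merge_vision_embeddings as added above (shapes and the placeholder id are illustrative; it assumes the function is importable from this utils.py):

import torch
# from utils import merge_vision_embeddings   # adjust to the package layout

IMAGE_TOKEN_ID = 32000                        # hypothetical placeholder token id
input_ids = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2])
inputs_embeds = torch.zeros(4, 8)             # (seq_len, embed_dim)
vision_embeddings = torch.randn(1, 2, 8)      # (batch, image_tokens, embed_dim)

out = merge_vision_embeddings(input_ids, inputs_embeds,
                              vision_embeddings, IMAGE_TOKEN_ID)
assert torch.equal(out[1:3], vision_embeddings.view(2, 8))   # placeholder positions overwritten in place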