File size: 5,418 Bytes
4e636d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
Union, overload, runtime_checkable)
from typing_extensions import TypeGuard
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.logger import init_logger
# Module-level logger obtained from vLLM's init_logger helper, keyed by this
# module's __name__; supports_lora() uses it to warn about models whose LoRA
# flag and LoRA attributes disagree.
logger = init_logger(__name__)
@runtime_checkable
class SupportsVision(Protocol):
    """Protocol that every vision language model (VLM) must satisfy.

    A conforming model exposes the ``supports_vision`` class flag and takes
    its multimodal configuration as a keyword-only constructor argument.
    """

    # Marker flag announcing that the model accepts vision inputs. Because it
    # is inherited through the MRO, models that list this protocol as a base
    # class do not need to redefine it.
    supports_vision: ClassVar[Literal[True]] = True

    def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
        ...
# runtime_checkable protocols that declare ClassVar (data) members cannot be
# used with issubclass, so a model *class* is checked by handing the class
# object itself to isinstance. Hence the flag is a plain attribute here and
# the constructor is described as __call__ on the class object.
@runtime_checkable
class _SupportsVisionType(Protocol):
    """Structural type of a model class (not an instance) that
    implements SupportsVision."""

    supports_vision: Literal[True]

    def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
        ...
@overload
def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
    ...


@overload
def supports_vision(model: object) -> TypeGuard[SupportsVision]:
    ...


def supports_vision(
    model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
    """Report whether ``model`` (a model class or instance) supports vision.

    Classes are matched against ``_SupportsVisionType`` (the class object is
    treated as an instance of that protocol); instances are matched against
    ``SupportsVision`` directly.
    """
    protocol = (_SupportsVisionType
                if isinstance(model, type) else SupportsVision)
    return isinstance(model, protocol)
@runtime_checkable
class SupportsLoRA(Protocol):
    """Protocol that every model with LoRA support must satisfy.

    Besides the ``supports_lora`` flag, an implementing model declares its
    LoRA-related module metadata (packed, supported and embedding module
    lists) as class-level attributes.
    """

    # Marker flag announcing LoRA support. It is inherited through the MRO,
    # so models that list this protocol as a base class need not redefine it.
    supports_lora: ClassVar[Literal[True]] = True

    # LoRA-specific metadata each implementing model has to provide.
    packed_modules_mapping: ClassVar[Dict[str, List[str]]]
    supported_lora_modules: ClassVar[List[str]]
    embedding_modules: ClassVar[Dict[str, str]]
    embedding_padding_modules: ClassVar[List[str]]

    def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
        # lora_config stays None whenever LoRA is not enabled.
        ...
# runtime_checkable protocols that declare ClassVar (data) members cannot be
# used with issubclass, so a model *class* is checked by handing the class
# object itself to isinstance. SupportsLoRA's ClassVars are therefore
# redeclared as plain attributes, and its __init__ is described as __call__.
@runtime_checkable
class _SupportsLoRAType(Protocol):
    """Structural type of a model class (not an instance) that
    implements SupportsLoRA."""

    supports_lora: Literal[True]

    packed_modules_mapping: Dict[str, List[str]]
    supported_lora_modules: List[str]
    embedding_modules: Dict[str, str]
    embedding_padding_modules: List[str]

    def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
        ...
@overload
def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
    ...


@overload
def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
    ...


def supports_lora(
    model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
    """Report whether ``model`` (a model class or instance) supports LoRA.

    When the structural check fails, a warning is logged for the two
    half-configured situations:

    * ``supports_lora`` is truthy but some LoRA attributes are absent;
    * every LoRA attribute is present but ``supports_lora`` is not truthy.
    """
    matched = _supports_lora(model)
    if matched:
        return matched

    # Which of the LoRA-specific attributes the model lacks.
    absent = tuple(
        name for name in (
            "packed_modules_mapping",
            "supported_lora_modules",
            "embedding_modules",
            "embedding_padding_modules",
        ) if not hasattr(model, name))
    flagged = getattr(model, "supports_lora", False)

    if flagged and absent:
        logger.warning(
            "The model (%s) sets `supports_lora=True`, "
            "but is missing LoRA-specific attributes: %s",
            model,
            absent,
        )
    elif not flagged and not absent:
        logger.warning(
            "The model (%s) contains all LoRA-specific attributes, "
            "but does not set `supports_lora=True`.", model)

    return matched
def _supports_lora(
    model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
    """Structural LoRA check for a class or instance, without any logging."""
    if not isinstance(model, type):
        return isinstance(model, SupportsLoRA)
    # Class objects are matched against the "type" protocol, which treats
    # the class itself as the instance being checked.
    return isinstance(model, _SupportsLoRAType)
@runtime_checkable
class HasInnerState(Protocol):
    """Protocol that every model with inner state must satisfy."""

    # Marker flag announcing that the model has inner state. Such models
    # usually need the scheduler_config (for max_num_seqs, among others);
    # at the time of writing only Jamba relies on it.
    has_inner_state: ClassVar[Literal[True]] = True

    def __init__(self,
                 *,
                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
        ...
# Same pattern as the other "type" protocols: runtime_checkable protocols that
# declare ClassVar (data) members cannot be used with issubclass, so a model
# *class* is checked by handing the class object itself to isinstance. The
# flag is therefore a plain attribute (not a ClassVar), and the constructor is
# described as __call__ on the class object rather than as __init__.
@runtime_checkable
class _HasInnerStateType(Protocol):
    """Structural type of a model class (not an instance) that
    implements HasInnerState."""

    has_inner_state: Literal[True]

    # __call__ is always present on class objects (via ``type``), so adding it
    # leaves isinstance results for classes unchanged; it lets static checkers
    # see the keyword-only constructor signature.
    def __call__(self,
                 *,
                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
        ...
# The Type[object] overload must come first: `object` also matches class
# objects, so listing it first would make the Type overload unreachable and
# narrow classes to HasInnerState instead of Type[HasInnerState].
@overload
def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]:
    ...


@overload
def has_inner_state(model: object) -> TypeGuard[HasInnerState]:
    ...


def has_inner_state(
    model: Union[Type[object], object]
) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]:
    """Report whether ``model`` (a model class or instance) has inner state.

    Classes are matched against ``_HasInnerStateType`` (the class object is
    treated as an instance of that protocol); instances are matched against
    ``HasInnerState`` directly.
    """
    if isinstance(model, type):
        return isinstance(model, _HasInnerStateType)
    return isinstance(model, HasInnerState)