from __future__ import annotations |
|
|
|
import warnings |
|
from typing import Any, Optional, Union |
|
|
|
from torch import nn |
|
from tqdm import tqdm |
|
|
|
from peft_mora.tuners import adalora, loha, lokr, lora, oft |
|
from peft_mora.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists |
|
from peft_mora.utils import ( |
|
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, |
|
ModulesToSaveWrapper, |
|
PeftType, |
|
_get_submodules, |
|
get_auto_gptq_quant_linear, |
|
) |
|
|
|
|
|
|
|
COMPATIBLE_TUNER_TYPES = (PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.ADALORA, PeftType.OFT) |
|
PREFIXES = [lora.LoraModel.prefix, lokr.LoKrModel.prefix, loha.LoHaModel.prefix, oft.OFTModel.prefix] |
|
Configs = Union[lora.LoraConfig, loha.LoHaConfig, lokr.LoKrConfig, adalora.AdaLoraConfig, oft.OFTConfig] |
|
Layers = (lora.layer.LoraLayer, loha.layer.LoHaLayer, lokr.layer.LoKrLayer, adalora.layer.AdaLoraLayer, oft.OFTLayer) |
|
|
|
|
|
class MixedModel(BaseTuner): |
|
""" |
|
    A class that allows mixing different types of adapters within a single model.
|
|
|
Note: This class should usually not be initialized directly. Instead, use `get_peft_model` with the argument |
|
`mixed=True`. |
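
    Example (a minimal sketch; it assumes that `peft_mora` mirrors upstream PEFT in exposing `get_peft_model` and
    the config classes at the package top level; target module and adapter names below are placeholders)::

        >>> from peft_mora import LoHaConfig, LoraConfig, get_peft_model
        >>> base_model = ...  # any torch / transformers model
        >>> config_lora = LoraConfig(target_modules=["q_proj", "v_proj"])
        >>> config_loha = LoHaConfig(target_modules=["q_proj", "v_proj"])
        >>> peft_model = get_peft_model(base_model, config_lora, adapter_name="lora", mixed=True)
        >>> peft_model.add_adapter("loha", config_loha)
        >>> peft_model.set_adapter(["lora", "loha"])  # activate both adapters at once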
|
|
|
Args: |
|
model (:obj:`nn.Module`): |
|
The model to be tuned. |
|
config (:obj:`PeftConfig`): |
|
The config of the model to be tuned. The adapter type must be compatible. |
|
adapter_name (:obj:`str`): |
|
The name of the first adapter. |
|
""" |
|
|
|
def __init__(self, model: nn.Module, config: Configs, adapter_name: str) -> None: |
|
super().__init__(model, config, adapter_name) |
|
|
|
def _check_new_adapter_config(self, config: Configs) -> None: |
|
""" |
|
A helper method to check the config when a new adapter is being added. |
|
|
|
Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. |
|
|
|
""" |
|
if not isinstance(config, Configs.__args__): |
|
raise ValueError( |
|
f"{self.__class__.__name__} only supports {COMPATIBLE_TUNER_TYPES} configs, but got {type(config)}." |
|
) |
|
|
|
        biases = (getattr(existing_config, "bias", None) for existing_config in self.peft_config.values())
|
biases = [bias for bias in biases if bias not in (None, "none")] |
|
if len(biases) > 1: |
|
raise ValueError( |
|
f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, " |
|
"set bias to 'none' for all adapters." |
|
) |
|
|
|
@staticmethod |
|
def _check_target_module_exists(config: Configs, key: str): |
|
return check_target_module_exists(config, key) |
|
|
|
def _create_and_replace( |
|
self, |
|
config: Configs, |
|
*args: Any, |
|
**kwargs: Any, |
|
) -> None: |
|
if isinstance(config, adalora.AdaLoraConfig): |
|
adalora.AdaLoraModel._create_and_replace(self, config, *args, **kwargs) |
|
elif isinstance(config, lora.LoraConfig): |
|
lora.LoraModel._create_and_replace(self, config, *args, **kwargs) |
|
elif isinstance(config, loha.LoHaConfig): |
|
loha.LoHaModel._create_and_replace(self, config, *args, **kwargs) |
|
elif isinstance(config, lokr.LoKrConfig): |
|
lokr.LoKrModel._create_and_replace(self, config, *args, **kwargs) |
|
elif isinstance(config, oft.OFTConfig): |
|
oft.OFTModel._create_and_replace(self, config, *args, **kwargs) |
|
else: |
|
raise ValueError(f"Unsupported config type {type(config)}, should be one of {COMPATIBLE_TUNER_TYPES}.") |
|
|
|
def _replace_module(self, parent, child_name, new_module, child) -> None: |
|
setattr(parent, child_name, new_module) |
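        # requires_grad is deliberately not touched here; _mark_only_adapters_as_trainable handles freezing and
        # unfreezing the right parameters after replacement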
|
|
|
|
|
|
|
|
|
        # the child layer wraps the original module; unpack it before copying weight, bias, and state over
        if hasattr(child, "base_layer"):
|
child = child.get_base_layer() |
|
elif hasattr(child, "quant_linear_module"): |
|
|
|
child = child.quant_linear_module |
|
|
|
if not hasattr(new_module, "base_layer"): |
|
new_module.weight = child.weight |
|
if hasattr(child, "bias"): |
|
new_module.bias = child.bias |
|
|
|
if getattr(child, "state", None) is not None: |
|
if hasattr(new_module, "base_layer"): |
|
new_module.base_layer.state = child.state |
|
else: |
|
new_module.state = child.state |
|
new_module.to(child.weight.device) |
|
|
|
|
|
        # move the newly created adapter sub-modules onto the same device as the original weight
        for name, module in new_module.named_modules():
|
if any(prefix in name for prefix in PREFIXES): |
|
module.to(child.weight.device) |
|
if "ranknum" in name: |
|
module.to(child.weight.device) |
|
|
|
def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: |
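        # freeze every parameter that does not belong to an adapter module; bias terms are selectively re-enabled below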
|
for n, p in model.named_parameters(): |
|
if not any(prefix in n for prefix in PREFIXES): |
|
p.requires_grad = False |
|
|
|
for active_adapter in self.active_adapters: |
|
bias = getattr(self.peft_config[active_adapter], "bias", "none") |
|
if bias == "none": |
|
continue |
|
|
|
if bias == "all": |
|
for n, p in model.named_parameters(): |
|
if "bias" in n: |
|
p.requires_grad = True |
|
elif bias == "lora_only": |
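                # only enable the bias of modules that actually carry an adapter layer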
|
|
|
for m in model.modules(): |
|
if isinstance(m, Layers) and hasattr(m, "bias") and m.bias is not None: |
|
m.bias.requires_grad = True |
|
else: |
|
raise ValueError(f"Requested bias: {bias}, is not implemented.") |
|
|
|
@staticmethod |
|
def _create_new_module(config, adapter_name, target, **kwargs): |
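        # quantized base layers are rejected up front; afterwards module creation is delegated to the matching tuner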
|
gptq_quantization_config = kwargs.get("gptq_quantization_config", None) |
|
AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config) |
|
if (gptq_quantization_config is not None) or (AutoGPTQQuantLinear is not None): |
|
raise ValueError(f"GPTQ quantization not supported for {config.peft_type.value} (yet).") |
|
|
|
loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) |
|
loaded_in_4bit = kwargs.pop("loaded_in_4bit", False) |
|
if loaded_in_8bit or loaded_in_4bit: |
|
raise ValueError(f"8bit and 4bit quantization not supported for {config.peft_type.value} (yet).") |
|
|
|
if isinstance(config, adalora.AdaLoraConfig): |
|
new_module = adalora.AdaLoraModel._create_new_module(config, adapter_name, target, **kwargs) |
|
elif isinstance(config, lora.LoraConfig): |
|
new_module = lora.LoraModel._create_new_module(config, adapter_name, target, **kwargs) |
|
elif isinstance(config, loha.LoHaConfig): |
|
new_module = loha.LoHaModel._create_new_module(config, adapter_name, target, **kwargs) |
|
elif isinstance(config, lokr.LoKrConfig): |
|
new_module = lokr.LoKrModel._create_new_module(config, adapter_name, target, **kwargs) |
|
elif isinstance(config, oft.OFTConfig): |
|
new_module = oft.OFTModel._create_new_module(config, adapter_name, target, **kwargs) |
|
else: |
|
raise ValueError(f"Unknown config type {type(config)}, should be one of {COMPATIBLE_TUNER_TYPES}.") |
|
return new_module |
|
|
|
def __getattr__(self, name: str): |
|
"""Forward missing attributes to the wrapped module.""" |
|
try: |
|
return super().__getattr__(name) |
|
except AttributeError: |
|
return getattr(self.model, name) |
|
|
|
def _set_adapter_layers(self, enabled=True): |
|
for module in self.model.modules(): |
|
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): |
|
module.enable_adapters(enabled) |
|
|
|
def enable_adapter_layers(self): |
|
self._set_adapter_layers(enabled=True) |
|
|
|
def disable_adapter_layers(self): |
|
for active_adapter in self.active_adapters: |
|
val = getattr(self.peft_config[active_adapter], "bias", "none") |
|
if val != "none": |
|
                msg = (

                    f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same "

                    "output as the base model would without adaptation."
|
) |
|
warnings.warn(msg) |
|
self._set_adapter_layers(enabled=False) |
|
|
|
def set_adapter(self, adapter_name: Union[str, list[str]]) -> None: |
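        """Set the active adapter(s) by name, unmerging any layers that are currently merged."""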
|
for module in self.model.modules(): |
|
if isinstance(module, Layers): |
|
if module.merged: |
|
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") |
|
module.unmerge() |
|
module.set_adapter(adapter_name) |
|
self.active_adapter = adapter_name |
|
|
|
@staticmethod |
|
def _prepare_adapter_config(peft_config, model_config): |
|
if peft_config.target_modules is None: |
|
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: |
|
raise ValueError("Please specify `target_modules` in `peft_config`") |
|
|
|
peft_config.target_modules = set( |
|
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] |
|
) |
|
return peft_config |
|
|
|
def _unload_and_optionally_merge( |
|
self, |
|
merge=True, |
|
progressbar: bool = False, |
|
safe_merge: bool = False, |
|
adapter_names: Optional[list[str]] = None, |
|
): |
|
if merge: |
|
if getattr(self.model, "quantization_method", None) == "gptq": |
|
raise ValueError("Cannot merge layers when the model is gptq quantized") |
|
|
|
def merge_recursively(module): |
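            # helper that walks down nested base_layer wrappers and merges each level into the layer below it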
|
|
|
path = [] |
|
layer = module |
|
while hasattr(layer, "base_layer"): |
|
path.append(layer) |
|
layer = layer.base_layer |
|
for layer_before, layer_after in zip(path[:-1], path[1:]): |
|
layer_after.merge(safe_merge=safe_merge, adapter_names=adapter_names) |
|
layer_before.base_layer = layer_after.base_layer |
|
module.merge(safe_merge=safe_merge, adapter_names=adapter_names) |
|
|
|
        # collect all module keys that do not themselves belong to an adapter layer (prefix match on tuner prefixes)
        key_list = [key for key, _ in self.model.named_modules() if not any(prefix in key for prefix in PREFIXES)]
|
desc = "Unloading " + ("and merging " if merge else "") + "model" |
|
|
|
for key in tqdm(key_list, disable=not progressbar, desc=desc): |
|
try: |
|
parent, target, target_name = _get_submodules(self.model, key) |
|
except AttributeError: |
|
continue |
|
|
|
if hasattr(target, "base_layer"): |
|
if merge: |
|
merge_recursively(target) |
|
self._replace_module(parent, target_name, target.get_base_layer(), target) |
|
elif isinstance(target, ModulesToSaveWrapper): |
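                # replace the wrapper with the copy that was stored for the active adapter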
|
|
|
new_module = target.modules_to_save[target.active_adapter] |
|
if hasattr(new_module, "base_layer"): |
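                    # the saved module may itself be a tuner layer; merge it if requested, then unwrap it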
|
|
|
if merge: |
|
new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names) |
|
new_module = new_module.get_base_layer() |
|
setattr(parent, target_name, new_module) |
|
|
|
return self.model |
|
|
|
def add_weighted_adapter(self, *args: Any, **kwargs: Any) -> None: |
|
raise NotImplementedError(f"Weighted adapters are not supported for {self.__class__.__name__} (yet).") |
|
|
|
def delete_adapter(self, adapter_name: Union[str, list[str]]) -> None: |
|
""" |
|
Deletes an existing adapter. |
|
|
|
Args: |
|
adapter_name (Union[str, list[str]]): Name of the adapter(s) to delete. |
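
        Example (illustrative; `model` is the mixed model and "loha" a previously registered adapter name)::

            >>> model.delete_adapter("loha")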
|
""" |
|
if isinstance(adapter_name, str): |
|
adapter_names = [adapter_name] |
|
else: |
|
adapter_names = adapter_name |
|
|
|
mismatched = set(adapter_names) - set(self.peft_config.keys()) |
|
if mismatched: |
|
raise ValueError( |
|
f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}" |
|
) |
|
|
|
for adapter_name in adapter_names: |
|
del self.peft_config[adapter_name] |
|
|
|
key_list = [key for key, _ in self.model.named_modules() if not any(prefix in key for prefix in PREFIXES)] |
|
new_adapter = None |
|
for key in key_list: |
|
_, target, _ = _get_submodules(self.model, key) |
|
if isinstance(target, BaseTunerLayer): |
|
target.delete_adapter(adapter_name) |
|
if new_adapter is None: |
|
new_adapter = target.active_adapters[:] |
|
|
|
self.active_adapter = new_adapter or [] |
|
|
|
def merge_and_unload( |
|
self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None |
|
) -> nn.Module: |
|
r""" |
|
        This method merges the adapter layers into the base model. This is needed if someone wants to use the base
        model as a standalone model.
|
|
|
Args: |
|
progressbar (`bool`): |
|
whether to show a progressbar indicating the unload and merge process |
|
safe_merge (`bool`): |
|
                whether to run the safe merging check, which verifies that the merged adapter weights do not contain
                NaNs
|
adapter_names (`List[str]`, *optional*): |
|
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults |
|
to `None`. |
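
        Example (illustrative; `peft_model` stands for a mixed model created with `get_peft_model(..., mixed=True)`)::

            >>> merged_model = peft_model.merge_and_unload(progressbar=True, safe_merge=True)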
|
""" |
|
return self._unload_and_optionally_merge( |
|
progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names |
|
) |
|
|
|
def unload(self) -> nn.Module: |
|
""" |
|
        Gets back the base model by removing all the adapter modules without merging. This gives back the original base
|
model. |
|
""" |
|
return self._unload_and_optionally_merge(merge=False) |
|
|
|
def generate(self, *args: Any, **kwargs: Any): |
|
return self.model.generate(*args, **kwargs) |
|
|