Spaces:
Running
on
Zero
Running
on
Zero
File size: 38,671 Bytes
d711508 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 |
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
import operator
import re
import warnings
from contextlib import contextmanager
from dataclasses import asdict, replace
from enum import Enum
from functools import partial, reduce
from itertools import chain
from typing import Literal, Optional
import torch
from torch import nn
from tqdm import tqdm
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.tuners_utils import (
BaseTuner,
BaseTunerLayer,
check_target_module_exists,
onload_layer,
replicate_layers,
)
from peft.utils import (
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
ModulesToSaveWrapper,
_freeze_adapter,
_get_submodules,
get_peft_model_state_dict,
get_quantization_config,
)
from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties
from .aqlm import dispatch_aqlm
from .awq import dispatch_awq
from .config import LoraConfig
from .eetq import dispatch_eetq
from .gptq import dispatch_gptq
from .hqq import dispatch_hqq
from .layer import Conv2d, LoraLayer, dispatch_default
from .tp_layer import dispatch_megatron
def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names):
    """Forward pre-hook that adds `adapter_names` to the keyword arguments of a forward call.

    Used for mixed adapter batch inference. `kwargs` is updated in place and handed back together with the
    untouched positional `args`; `target` is not used.
    """
    kwargs.update(adapter_names=adapter_names)
    return (args, kwargs)
class LoraModel(BaseTuner):
"""
Creates Low Rank Adapter (LoRA) model from a pretrained transformers model.
The method is described in detail in https://arxiv.org/abs/2106.09685.
Args:
model ([`torch.nn.Module`]): The model to be adapted.
config ([`LoraConfig`]): The configuration of the Lora model.
adapter_name (`str`): The name of the adapter, defaults to `"default"`.
Returns:
`torch.nn.Module`: The Lora model.
Example:
```py
>>> from transformers import AutoModelForSeq2SeqLM
>>> from peft import LoraModel, LoraConfig
>>> config = LoraConfig(
... task_type="SEQ_2_SEQ_LM",
... r=8,
... lora_alpha=32,
... target_modules=["q", "v"],
... lora_dropout=0.01,
... )
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> lora_model = LoraModel(model, config, "default")
```
```py
>>> import torch
>>> import transformers
>>> from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
>>> rank = ...
>>> target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc_in", "fc_out", "wte"]
>>> config = LoraConfig(
... r=4, lora_alpha=16, target_modules=target_modules, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"
... )
>>> quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
>>> tokenizer = transformers.AutoTokenizer.from_pretrained(
... "kakaobrain/kogpt",
... revision="KoGPT6B-ryan1.5b-float16", # or float32 version: revision=KoGPT6B-ryan1.5b
... bos_token="[BOS]",
... eos_token="[EOS]",
... unk_token="[UNK]",
... pad_token="[PAD]",
... mask_token="[MASK]",
... )
>>> model = transformers.GPTJForCausalLM.from_pretrained(
... "kakaobrain/kogpt",
... revision="KoGPT6B-ryan1.5b-float16", # or float32 version: revision=KoGPT6B-ryan1.5b
... pad_token_id=tokenizer.eos_token_id,
... use_cache=False,
... device_map={"": rank},
... torch_dtype=torch.float16,
... quantization_config=quantization_config,
... )
>>> model = prepare_model_for_kbit_training(model)
>>> lora_model = get_peft_model(model, config)
```
**Attributes**:
- **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted.
- **peft_config** ([`LoraConfig`]): The configuration of the Lora model.
"""
prefix: str = "lora_"
def __init__(self, model, config, adapter_name) -> None:
    """Initialise the tuner by passing `model`, `config` and `adapter_name` to the parent initializer.

    No LoRA-specific work is done here; everything is delegated to `super().__init__`.
    """
    super().__init__(model, config, adapter_name)
def _check_new_adapter_config(self, config: LoraConfig) -> None:
    """
    Validate the config of an adapter that is about to be added.

    Raises:
        ValueError: If more than one adapter config is registered in `self.peft_config` and the new `config`
            has a `bias` other than `"none"`.
    """
    # TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check
    # does not fully correspond to the error message.
    if len(self.peft_config) <= 1:
        # zero or one adapter registered: any bias setting is accepted
        return
    if config.bias == "none":
        return
    owner = self.__class__.__name__
    raise ValueError(
        f"{owner} supports only 1 adapter with bias. When using multiple adapters, "
        "set bias to 'none' for all adapters."
    )
@staticmethod
def _check_target_module_exists(lora_config, key):
    """Return the result of `check_target_module_exists(lora_config, key)` unchanged."""
    return check_target_module_exists(lora_config, key)
def _prepare_model(self, peft_config: LoraConfig, model: nn.Module):
    r"""
    Hook that may alter the model structure before the adapter is applied.

    The only modification made here is layer replication: when `peft_config.layer_replication` is truthy,
    `replicate_layers` is called with the model and that setting. Otherwise the model is left untouched.

    Args:
        peft_config (`PeftConfig`):
            The prepared adapter config.
        model (`nn.Module`):
            The model that is going to be adapted.
    """
    if not peft_config.layer_replication:
        return
    replicate_layers(model, peft_config.layer_replication)
def _create_and_replace(
    self,
    lora_config,
    adapter_name,
    target,
    target_name,
    parent,
    current_key,
):
    """
    Add the adapter `adapter_name` to `target`, either by updating it in place or by swapping in a new module.

    If `target` is already a `LoraLayer` (but not an `AdaLoraLayer`), its `update_layer` method is called.
    Otherwise a new module is built with `_create_new_module` and installed on `parent` under `target_name`
    via `_replace_module`.

    Raises:
        ValueError: If `current_key` is `None`.
    """
    if current_key is None:
        raise ValueError("Current Key shouldn't be `None`")

    # Regexp matching - Find key which matches current target_name in patterns provided
    # The first pattern key for which `.*\.<key>$` matches `current_key` wins; if none matches, `current_key`
    # itself is used for the lookups below, which then fall back to the config-wide `r` / `lora_alpha` unless
    # `current_key` is literally a key of the pattern dicts.
    # NOTE(review): `key` is interpolated into the regex unescaped, so it is treated as a regex fragment.
    pattern_keys = list(chain(lora_config.rank_pattern.keys(), lora_config.alpha_pattern.keys()))
    target_name_key = next(filter(lambda key: re.match(rf".*\.{key}$", current_key), pattern_keys), current_key)
    r = lora_config.rank_pattern.get(target_name_key, lora_config.r)
    alpha = lora_config.alpha_pattern.get(target_name_key, lora_config.lora_alpha)

    # Keyword arguments handed to `_create_new_module` when a brand-new module has to be created.
    kwargs = {
        "r": r,
        "lora_alpha": alpha,
        "lora_dropout": lora_config.lora_dropout,
        "fan_in_fan_out": lora_config.fan_in_fan_out,
        "init_lora_weights": lora_config.init_lora_weights,
        "use_rslora": lora_config.use_rslora,
        "use_dora": lora_config.use_dora,
        "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
        "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
    }

    # Add a `<method>_quantization_config` entry for every method for which the model reports a config.
    quant_methods = ["gptq", "aqlm", "awq"]
    for quant_method in quant_methods:
        quantization_config = get_quantization_config(self.model, method=quant_method)
        if quantization_config is not None:
            kwargs[f"{quant_method}_quantization_config"] = quantization_config

    # note: AdaLoraLayer is a subclass of LoraLayer, we need to exclude it
    from peft.tuners.adalora import AdaLoraLayer

    if isinstance(target, LoraLayer) and not isinstance(target, AdaLoraLayer):
        # The target already is a LoRA layer: add/update the adapter on the existing layer.
        target.update_layer(
            adapter_name,
            r,
            lora_alpha=alpha,
            lora_dropout=lora_config.lora_dropout,
            init_lora_weights=lora_config.init_lora_weights,
            use_rslora=lora_config.use_rslora,
            use_dora=lora_config.use_dora,
        )
    else:
        new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
        if adapter_name not in self.active_adapters:
            # adding an additional adapter: it is not automatically trainable
            new_module.requires_grad_(False)
        self._replace_module(parent, target_name, new_module, target)
def _replace_module(self, parent, child_name, new_module, child):
    """
    Install `new_module` on `parent` under the attribute `child_name`, taking over weights, state and device
    placement from `child`.

    Args:
        parent: Object on which the attribute `child_name` is set.
        child_name: Attribute name under which `new_module` is stored.
        new_module: The replacement module.
        child: The module being replaced; if it has a `base_layer` attribute, that inner layer is used instead.
    """
    setattr(parent, child_name, new_module)
    # It's not necessary to set requires_grad here, as that is handled by
    # _mark_only_adapters_as_trainable

    # child layer wraps the original module, unpack it
    if hasattr(child, "base_layer"):
        child = child.base_layer

    # A replacement without its own `base_layer` takes the weight (and bias, if `child` has that attribute)
    # directly from `child`.
    if not hasattr(new_module, "base_layer"):
        new_module.weight = child.weight
        if hasattr(child, "bias"):
            new_module.bias = child.bias

    # Carry over a non-None `state` attribute, then move the whole replacement to the device of
    # `child.weight`.
    if getattr(child, "state", None) is not None:
        if hasattr(new_module, "base_layer"):
            new_module.base_layer.state = child.state
        else:
            new_module.state = child.state
        new_module.to(child.weight.device)

    # dispatch to correct device
    # Only submodules whose name contains the adapter prefix or "ranknum" are moved. The reference device is
    # taken from `child.qweight`, else `child.W_q`, else `child.weight` (first attribute that exists).
    for name, module in new_module.named_modules():
        if (self.prefix in name) or ("ranknum" in name):
            weight = (
                child.qweight
                if hasattr(child, "qweight")
                else child.W_q
                if hasattr(child, "W_q")
                else child.weight
            )
            module.to(weight.device)
def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
    """
    Freeze every parameter whose name does not contain `self.prefix`, then re-enable bias gradients
    according to the `bias` setting of each active adapter.

    Bias modes: `"none"` changes nothing further, `"all"` enables gradients for every parameter with "bias" in
    its name, `"lora_only"` enables gradients for the non-None `bias` attribute of `LoraLayer` modules.

    Raises:
        NotImplementedError: If an active adapter has any other `bias` value.
    """
    for param_name, param in model.named_parameters():
        if self.prefix in param_name:
            continue
        param.requires_grad = False

    for adapter in self.active_adapters:
        bias = self.peft_config[adapter].bias
        if bias == "none":
            continue

        if bias == "all":
            for param_name, param in model.named_parameters():
                if "bias" in param_name:
                    param.requires_grad = True
            continue

        if bias == "lora_only":
            for module in model.modules():
                if not isinstance(module, LoraLayer):
                    continue
                if hasattr(module, "bias") and module.bias is not None:
                    module.bias.requires_grad = True
            continue

        raise NotImplementedError(f"Requested bias: {bias}, is not implemented.")
@staticmethod
def _create_new_module(lora_config, adapter_name, target, **kwargs):
    """
    Build the LoRA replacement for `target` by asking a list of backend dispatchers in order.

    Each dispatcher is called as `dispatcher(target, adapter_name, lora_config=lora_config, **kwargs)`; the
    first one that returns something other than `None` supplies the result.

    Raises:
        ValueError: If every dispatcher returns `None`.
    """
    # The order matters, because the first match is always used. Therefore, the default layers are tried last.
    candidates = []

    # avoid eager bnb import
    if is_bnb_available():
        from .bnb import dispatch_bnb_8bit

        candidates.append(dispatch_bnb_8bit)

    if is_bnb_4bit_available():
        from .bnb import dispatch_bnb_4bit

        candidates.append(dispatch_bnb_4bit)

    candidates += [
        dispatch_eetq,
        dispatch_aqlm,
        dispatch_awq,
        dispatch_gptq,
        dispatch_hqq,
        dispatch_megatron,
        dispatch_default,
    ]

    for dispatch in candidates:
        replacement = dispatch(target, adapter_name, lora_config=lora_config, **kwargs)
        if replacement is not None:
            # first match wins
            return replacement

    # no module could be matched
    raise ValueError(
        f"Target module {target} is not supported. Currently, only the following modules are supported: "
        "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`."
    )
def __getattr__(self, name: str):
    """Forward missing attributes to the wrapped module.

    The parent class lookup is tried first. Only if that raises `AttributeError` is the attribute looked up
    on `self.model`.

    Raises:
        AttributeError: If `name` is `"model"` and the parent lookup failed, or if neither this object nor
            `self.model` has the attribute.
    """
    try:
        return super().__getattr__(name)  # defer to nn.Module's logic
    except AttributeError:
        if name == "model":
            # `self.model` itself is missing (e.g. the instance was created without running `__init__`).
            # Evaluating `self.model` below would re-enter this method forever, so re-raise instead.
            raise
        return getattr(self.model, name)
def get_peft_config_as_dict(self, inference: bool = False):
    """
    Return the configs of all adapters as plain dictionaries, keyed by adapter name.

    Each config in `self.peft_config` is converted with `dataclasses.asdict`, and top-level `Enum` values are
    replaced by their `.value`.

    Args:
        inference (`bool`):
            If `True`, `"inference_mode"` is set to `True` in every returned config dict.

    Returns:
        `dict`: Mapping from adapter name to that adapter's config dict; empty if no adapter is registered.
    """
    config_dict = {}
    for adapter_name, adapter_config in self.peft_config.items():
        config = {
            field: value.value if isinstance(value, Enum) else value
            for field, value in asdict(adapter_config).items()
        }
        if inference:
            config["inference_mode"] = True
        config_dict[adapter_name] = config
    # Return the accumulated mapping, not just the config from the last loop iteration.
    return config_dict
def _set_adapter_layers(self, enabled: bool = True) -> None:
    """Call `enable_adapters(enabled)` on every `BaseTunerLayer` / `ModulesToSaveWrapper` in `self.model`."""
    for layer in self.model.modules():
        if not isinstance(layer, (BaseTunerLayer, ModulesToSaveWrapper)):
            continue
        layer.enable_adapters(enabled)
def enable_adapter_layers(self) -> None:
    """Enable all adapters.

    Call this if you have previously disabled all adapters and want to re-enable them. Implemented as
    `self._set_adapter_layers(enabled=True)`.
    """
    self._set_adapter_layers(enabled=True)
def disable_adapter_layers(self) -> None:
    """Disable all adapters.

    When disabling all adapters, the model output corresponds to the output of the base model. A warning is
    emitted for every active adapter whose `bias` setting is not `"none"`, because the output may then differ
    from the base model's.
    """
    for adapter in self.active_adapters:
        val = self.peft_config[adapter].bias
        if val == "none":
            continue
        warnings.warn(
            f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same "
            "output as the the base model would without adaption."
        )
    self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name: str | list[str]) -> None:
    """Set the active adapter(s).

    Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
    not desired, use the following code.

    ```py
    >>> for name, param in model_peft.named_parameters():
    ...     if ...:  # some check on name (ex. if 'lora' in name)
    ...         param.requires_grad = False
    ```

    LoRA layers that are currently merged are unmerged (with a warning) before the adapter is switched.

    Args:
        adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated.
    """
    lora_layers = (module for module in self.model.modules() if isinstance(module, LoraLayer))
    for layer in lora_layers:
        if layer.merged:
            warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
            layer.unmerge()
        layer.set_adapter(adapter_name)
    self.active_adapter = adapter_name
@contextmanager
def _enable_peft_forward_hooks(self, *args, **kwargs):
    """
    Context manager that injects `adapter_names` into the forward call of every `LoraLayer`.

    If `adapter_names` is not passed (or is `None`), this is a no-op. Otherwise a forward pre-hook is
    registered on each `LoraLayer` of this module for the duration of the `with` block; the hooks are removed
    on exit, including when the body raises.

    Raises:
        ValueError: If `adapter_names` is given while the model is in training mode.
    """
    # If adapter_names is passed as an argument, we inject it into the forward arguments.
    adapter_names = kwargs.pop("adapter_names", None)
    if adapter_names is None:
        # nothing to do
        yield
        return

    if self.training:
        raise ValueError("Cannot pass `adapter_names` when the model is in training mode.")

    hook_handles = []
    try:
        for module in self.modules():
            if isinstance(module, LoraLayer):
                pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
                handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                hook_handles.append(handle)

        yield
    finally:
        # Always detach the hooks, even if registration or the wrapped forward pass raised; otherwise stale
        # hooks would keep injecting `adapter_names` into later forward calls.
        for handle in hook_handles:
            handle.remove()
def _check_merge_allowed(self):
    """Verify that the configuration supports merging.

    Currently gptq quantization and replicated layers do not support merging.

    Raises:
        ValueError: If `self.model.quantization_method` is `"gptq"`, or if any adapter config in
            `self.peft_config` has a truthy `layer_replication` setting.
    """
    if getattr(self.model, "quantization_method", None) == "gptq":
        raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
    # `self.peft_config` maps adapter names to configs, so `layer_replication` has to be read from each
    # config; looking it up as a key of the mapping would only match an adapter with that literal name.
    if any(getattr(config, "layer_replication", None) for config in self.peft_config.values()):
        raise ValueError("Cannot merge LORA layers when base model layers are replicated")
@staticmethod
def _prepare_adapter_config(peft_config, model_config):
if peft_config.target_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = set(
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]]
)
return peft_config
def _unload_and_optionally_merge(
    self,
    merge=True,
    progressbar: bool = False,
    safe_merge: bool = False,
    adapter_names: Optional[list[str]] = None,
):
    """
    Replace every adapter-wrapped module in `self.model` by its base layer, optionally merging first.

    Args:
        merge: If truthy, `_check_merge_allowed` is called up front and each wrapped layer's `merge` method is
            called before it is replaced.
        progressbar: Whether to show a tqdm progress bar while iterating over the modules.
        safe_merge: Passed to each layer's `merge` call.
        adapter_names: Passed to each layer's `merge` call.

    Returns:
        `self.model` after the replacements.
    """
    if merge:
        self._check_merge_allowed()

    # Snapshot the module names first (skipping names containing the adapter prefix), because modules are
    # replaced while iterating.
    key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
    desc = "Unloading " + ("and merging " if merge else "") + "model"
    for key in tqdm(key_list, disable=not progressbar, desc=desc):
        try:
            parent, target, target_name = _get_submodules(self.model, key)
        except AttributeError:
            # The lookup for this key failed with AttributeError; skip the key.
            continue
        with onload_layer(target):
            if hasattr(target, "base_layer"):
                if merge:
                    target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
                self._replace_module(parent, target_name, target.get_base_layer(), target)
            elif isinstance(target, ModulesToSaveWrapper):
                # save any additional trainable modules part of `modules_to_save`
                new_module = target.modules_to_save[target.active_adapter]
                if hasattr(new_module, "base_layer"):
                    # check if the module is itself a tuner layer
                    if merge:
                        new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
                    new_module = new_module.get_base_layer()
                setattr(parent, target_name, new_module)

    return self.model
def _check_add_weighted_adapter(
    self, adapters: list[str], combination_type: str, svd_rank: int | None
) -> tuple[str, int, str | set[str]]:
    """
    Helper function to check if the arguments to add_weighted_adapter are valid and compatible with the underlying
    model.

    Args:
        adapters: Names of the adapters to combine; each must be a key of `self.peft_config`.
        combination_type: Requested combination type. Forced to `"linear"` when only one adapter is given.
        svd_rank: Rank to use for the `*svd` combination types; if falsy, the largest adapter rank is used.

    Returns:
        A tuple `(combination_type, new_rank, new_target_modules)`. `new_target_modules` is a regex string
        when the adapters' `target_modules` are strings, and a set when they are sets.

    Raises:
        ValueError: If an adapter does not exist, if several adapters share a `modules_to_save` module, if the
            ranks differ for a combination type that needs equal ranks, if `combination_type` is unknown, or
            if the adapters' `target_modules` are of different types.
        TypeError: If `target_modules` is neither a `str` nor a `set`.
    """
    for adapter in adapters:
        if adapter not in self.peft_config:
            raise ValueError(f"Adapter {adapter} does not exist")

    # If more than one of the adapters targets the same module with modules_to_save, raise an error, as these
    # modules cannot be merged. First, find the ModulesToSaveWrapper instances in the model, then check if they
    # have modules for the adapters to be merged.
    modules_to_save_wrappers = [module for module in self.modules() if isinstance(module, ModulesToSaveWrapper)]
    problematic_wrappers = [
        wrapper
        for wrapper in modules_to_save_wrappers
        if sum(adapter in wrapper.modules_to_save for adapter in adapters) > 1
    ]
    if problematic_wrappers:
        raise ValueError(
            "Cannot add weighted adapters if they target the same module with modules_to_save, but found "
            f"{len(problematic_wrappers)} such instance(s)."
        )

    # if there is only one adapter, we can only use linear merging
    combination_type = "linear" if len(adapters) == 1 else combination_type

    adapters_ranks = [self.peft_config[adapter].r for adapter in adapters]
    if combination_type in ("linear", "ties", "dare_ties", "dare_linear", "magnitude_prune"):
        # all adapters ranks should be same, new rank is just this value
        if len(set(adapters_ranks)) != 1:
            raise ValueError(
                "All adapters must have the same r value when using combination_type linear, ties, dare_ties, "
                "dare_linear or magnitude_prune."
            )
        new_rank = adapters_ranks[0]
    elif combination_type == "cat":
        # adapters ranks may be different, new rank is sum of all ranks
        # be careful, because output adapter rank may be really big if mixing a lot of adapters
        new_rank = sum(adapters_ranks)
    elif combination_type.endswith("svd"):
        # new rank is the max of all ranks of the adapters if not provided
        new_rank = svd_rank or max(adapters_ranks)
    else:
        raise ValueError(f"Invalid combination_type: {combination_type}")

    target_module_types = [type(self.peft_config[adapter].target_modules) for adapter in adapters]
    if not target_module_types:
        raise ValueError(f"Found no adapter matching the names in {adapters}")
    if len(set(target_module_types)) > 1:
        raise ValueError(
            "all adapter configs should follow the same target modules type. "
            "Combining adapters with `target_modules` type being a mix of list/set and string is not supported."
        )

    if target_module_types[0] is str:
        # string targets are regexes: combine them as alternatives
        new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters)
    elif target_module_types[0] is set:
        # set targets: take the union over all adapters
        new_target_modules = reduce(
            operator.or_, (self.peft_config[adapter].target_modules for adapter in adapters)
        )
    else:
        raise TypeError(f"Invalid type {target_module_types[0]} found in target_modules")

    return combination_type, new_rank, new_target_modules
def add_weighted_adapter(
    self,
    adapters: list[str],
    weights: list[float],
    adapter_name: str,
    combination_type: str = "svd",
    svd_rank: int | None = None,
    svd_clamp: float | None = None,
    svd_full_matrices: bool = True,
    svd_driver: str | None = None,
    density: float | None = None,
    majority_sign_method: Literal["total", "frequency"] = "total",
) -> None:
    """
    This method adds a new adapter by merging the given adapters with the given weights.

    When using the `cat` combination_type you should be aware that rank of the resulting adapter will be equal to
    the sum of all adapters ranks. So it's possible that the mixed adapter may become too big and result in OOM
    errors.

    If an adapter called `adapter_name` already exists, the method returns without doing anything. `adapters`
    and `weights` are paired with `zip`, so surplus entries in the longer list are ignored.

    Args:
        adapters (`list`):
            List of adapter names to be merged.
        weights (`list`):
            List of weights for each adapter.
        adapter_name (`str`):
            Name of the new adapter.
        combination_type (`str`):
            The merging type can be one of [`svd`, `linear`, `cat`, `ties`, `ties_svd`, `dare_ties`, `dare_linear`,
            `dare_ties_svd`, `dare_linear_svd`, `magnitude_prune`, `magnitude_prune_svd`]. When using the `cat`
            combination_type, the rank of the resulting adapter is equal to the sum of all adapters ranks (the
            mixed adapter may be too big and result in OOM errors).
        svd_rank (`int`, *optional*):
            Rank of output adapter for svd. If None provided, will use max rank of merging adapters.
        svd_clamp (`float`, *optional*):
            A quantile threshold for clamping SVD decomposition output. If None is provided, do not perform
            clamping. Defaults to None.
        svd_full_matrices (`bool`, *optional*):
            Controls whether to compute the full or reduced SVD, and consequently, the shape of the returned
            tensors U and Vh. Defaults to True.
        svd_driver (`str`, *optional*):
            Name of the cuSOLVER method to be used. This keyword argument only works when merging on CUDA. Can be
            one of [None, `gesvd`, `gesvdj`, `gesvda`]. For more info please refer to `torch.linalg.svd`
            documentation. Defaults to None.
        density (`float`, *optional*):
            Value between 0 and 1. 0 means all values are pruned and 1 means no values are pruned. Should be used
            with [`ties`, `ties_svd`, `dare_ties`, `dare_linear`, `dare_ties_svd`, `dare_linear_svd`,
            `magnitude_prune`, `magnitude_prune_svd`]
        majority_sign_method (`str`):
            The method, should be one of ["total", "frequency"], to use to get the magnitude of the sign values.
            Should be used with [`ties`, `ties_svd`, `dare_ties`, `dare_ties_svd`]

    Raises:
        ValueError: If one of `adapters` does not exist, if `_check_add_weighted_adapter` rejects the
            arguments, or if no matching LoRA weights are found for a layer.
    """

    # An adapter with this name already exists: silently keep it.
    if adapter_name in list(self.peft_config.keys()):
        return
    for adapter in adapters:
        if adapter not in list(self.peft_config.keys()):
            raise ValueError(f"Adapter {adapter} does not exist")

    combination_type, new_rank, new_target_modules = self._check_add_weighted_adapter(
        adapters=adapters,
        combination_type=combination_type,
        svd_rank=svd_rank,
    )

    # The new config is a copy of the first adapter's config with rank, alpha and target modules overridden.
    # `lora_alpha` is set equal to the new rank.
    self.peft_config[adapter_name] = replace(
        self.peft_config[adapters[0]],
        r=new_rank,
        lora_alpha=new_rank,
        target_modules=new_target_modules,
    )
    self.inject_adapter(self.model, adapter_name)

    # Do we really need that?
    _freeze_adapter(self.model, adapter_name)

    key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
    for key in key_list:
        _, target, _ = _get_submodules(self.model, key)
        if isinstance(target, LoraLayer):
            # Locate the freshly injected A/B weights of the new adapter on this layer; layers that do not
            # carry the new adapter are skipped.
            if adapter_name in target.lora_A:
                target_lora_A = target.lora_A[adapter_name].weight
                target_lora_B = target.lora_B[adapter_name].weight
            elif adapter_name in target.lora_embedding_A:
                target_lora_A = target.lora_embedding_A[adapter_name]
                target_lora_B = target.lora_embedding_B[adapter_name]
            else:
                continue

            # Start from all-zero weights before writing the combined values.
            target_lora_A.data = target_lora_A.data * 0.0
            target_lora_B.data = target_lora_B.data * 0.0
            if combination_type == "cat":
                loras_A, loras_B = [], []
                for adapter, weight in zip(adapters, weights):
                    if adapter in target.lora_A:
                        current_adapter_lora_A = target.lora_A[adapter].weight
                        current_adapter_lora_B = target.lora_B[adapter].weight
                    elif adapter in target.lora_embedding_A:
                        current_adapter_lora_A = target.lora_embedding_A[adapter]
                        current_adapter_lora_B = target.lora_embedding_B[adapter]
                    else:
                        continue
                    # The adapter weight and its scaling are applied to the A matrix only.
                    loras_A.append(current_adapter_lora_A.data * weight * target.scaling[adapter])
                    loras_B.append(current_adapter_lora_B.data)

                if len(loras_A) == 0:
                    raise ValueError("No matching LoRAs found. Please raise an issue on GitHub.")
                loras_A = torch.cat(loras_A, dim=0)
                loras_B = torch.cat(loras_B, dim=1)
                # Write the concatenated blocks into the leading rows of A / leading columns of B.
                target_lora_A.data[: loras_A.shape[0], :] = loras_A
                target_lora_B.data[:, : loras_B.shape[1]] = loras_B
            elif combination_type in [
                "svd",
                "ties_svd",
                "dare_linear_svd",
                "dare_ties_svd",
                "magnitude_prune_svd",
            ]:
                target_lora_A.data, target_lora_B.data = self._svd_generalized_task_arithmetic_weighted_adapter(
                    combination_type,
                    adapters,
                    weights,
                    new_rank,
                    target,
                    target_lora_A,
                    target_lora_B,
                    density,
                    majority_sign_method,
                    svd_clamp,
                    full_matrices=svd_full_matrices,
                    driver=svd_driver,
                )
            elif combination_type in ["linear", "ties", "dare_linear", "dare_ties", "magnitude_prune"]:
                target_lora_A.data, target_lora_B.data = self._generalized_task_arithmetic_weighted_adapter(
                    combination_type, adapters, weights, target, density, majority_sign_method
                )
def _svd_generalized_task_arithmetic_weighted_adapter(
    self,
    combination_type,
    adapters,
    weights,
    new_rank,
    target,
    target_lora_A,
    target_lora_B,
    density,
    majority_sign_method,
    clamp=None,
    full_matrices=True,
    driver=None,
):
    """
    Combine the delta weights of several adapters on `target` and factorise the result with a truncated SVD.

    Returns:
        A tuple `(Vh, U)`, both truncated to `new_rank`; the singular values are multiplied into `U`. The
        caller assigns the first element to the new adapter's A weight and the second to its B weight.

    Raises:
        ValueError: If none of `adapters` is present on `target`, or if `combination_type` is not one of the
            `*svd` types handled below.
    """
    valid_adapters = []
    valid_weights = []
    is_embedding = any(adapter in target.lora_embedding_A for adapter in adapters)
    # Keep only adapters that exist on this layer; each weight is multiplied by the adapter's scaling.
    for adapter, weight in zip(adapters, weights):
        if adapter in target.lora_A or adapter in target.lora_embedding_A:
            valid_adapters.append(adapter)
            valid_weights.append(weight * target.scaling[adapter])

    # if no valid adapter, nothing to do
    if len(valid_adapters) == 0:
        raise ValueError("No matching LoRAs found. Please raise an issue on Github.")
    delta_weight = [target.get_delta_weight(adapter) for adapter in valid_adapters]
    valid_weights = torch.tensor(valid_weights).to(delta_weight[0].device)
    # Merge the per-adapter delta weights into a single tensor with the requested strategy.
    if combination_type == "svd":
        delta_weight = task_arithmetic(delta_weight, valid_weights)
    elif combination_type == "ties_svd":
        delta_weight = ties(delta_weight, valid_weights, density, majority_sign_method)
    elif combination_type == "dare_linear_svd":
        delta_weight = dare_linear(delta_weight, valid_weights, density)
    elif combination_type == "dare_ties_svd":
        delta_weight = dare_ties(delta_weight, valid_weights, density, majority_sign_method)
    elif combination_type == "magnitude_prune_svd":
        delta_weight = magnitude_prune(delta_weight, valid_weights, density)
    else:
        raise ValueError(f"Invalid value passed to combination type: {combination_type}")

    # Conv2d deltas are brought into matrix form before the SVD: flattened from dim 1 on, or squeezed when
    # the kernel size is 1x1.
    conv2d = isinstance(target, Conv2d)
    if conv2d:
        conv2d_1x1 = target.weight.size()[2:4] == (1, 1)
        if not conv2d_1x1:
            delta_weight = delta_weight.flatten(start_dim=1)
        else:
            delta_weight = delta_weight.squeeze()
    if (hasattr(target, "fan_in_fan_out") and target.fan_in_fan_out) or is_embedding:
        delta_weight = delta_weight.T

    # based on https://github.com/kohya-ss/sd-scripts/blob/main/networks/svd_merge_lora.py#L114-L131
    U, S, Vh = torch.linalg.svd(delta_weight, full_matrices=full_matrices, driver=driver)
    # Keep the first `new_rank` components and fold the singular values into U.
    U = U[:, :new_rank]
    S = S[:new_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:new_rank, :]
    if clamp is not None:
        # Clamp both factors symmetrically to [-q, q], where q is the `clamp` quantile of all their entries.
        dist = torch.cat([U.flatten(), Vh.flatten()])
        hi_val = torch.quantile(dist, clamp)
        low_val = -hi_val
        U = U.clamp(low_val, hi_val)
        Vh = Vh.clamp(low_val, hi_val)
    if conv2d:
        # Restore the shapes of the target adapter's B and A weights.
        U = U.reshape(target_lora_B.data.shape)
        Vh = Vh.reshape(target_lora_A.data.shape)
    return Vh, U
def _generalized_task_arithmetic_weighted_adapter(
    self,
    combination_type,
    adapters,
    weights,
    target,
    density,
    majority_sign_method,
):
    """
    Combine the LoRA A and B weights of several adapters on `target` without going through an SVD.

    The A tensors and the B tensors are merged separately, both with the same per-adapter weights.

    Args:
        combination_type: One of `"linear"`, `"ties"`, `"dare_linear"`, `"dare_ties"`, `"magnitude_prune"`.
        adapters: Names of the adapters to combine; adapters not present on `target` are skipped.
        weights: One weight per adapter, paired with `adapters` via `zip`.
        target: The LoRA layer whose adapter weights are combined.
        density: Passed to the pruning-based merge functions.
        majority_sign_method: Passed to the ties-based merge functions.

    Returns:
        A list `[combined_A, combined_B]`, cast to the dtype of the first adapter's A weight.

    Raises:
        ValueError: If none of `adapters` is present on `target`, or if `combination_type` is invalid.
    """
    # account weights for LoRA A and B layers.
    valid_weights = []
    lora_A_deltas = []
    lora_B_deltas = []
    for adapter, weight in zip(adapters, weights):
        if adapter in target.lora_A:
            current_adapter_lora_A = target.lora_A[adapter].weight
            current_adapter_lora_B = target.lora_B[adapter].weight
        elif adapter in target.lora_embedding_A:
            current_adapter_lora_A = target.lora_embedding_A[adapter]
            current_adapter_lora_B = target.lora_embedding_B[adapter]
        else:
            continue
        # The square root is taken because the same weight is applied to both the A and the B tensors.
        valid_weights.append(math.sqrt(weight * target.scaling[adapter]))
        lora_A_deltas.append(current_adapter_lora_A.data)
        lora_B_deltas.append(current_adapter_lora_B.data)

    # Fail with the same explicit error as the other combination paths instead of an IndexError on the
    # `lora_A_deltas[0]` lookups below.
    if not lora_A_deltas:
        raise ValueError("No matching LoRAs found. Please raise an issue on GitHub.")

    valid_weights = torch.tensor(valid_weights).to(lora_A_deltas[0].device)
    lora_deltas = [lora_A_deltas, lora_B_deltas]
    dtype = lora_A_deltas[0].dtype
    for i, task_tensors in enumerate(lora_deltas):
        if combination_type == "linear":
            lora_deltas[i] = task_arithmetic(task_tensors, valid_weights)
        elif combination_type == "ties":
            lora_deltas[i] = ties(task_tensors, valid_weights, density, majority_sign_method)
        elif combination_type == "dare_linear":
            lora_deltas[i] = dare_linear(task_tensors, valid_weights, density)
        elif combination_type == "dare_ties":
            lora_deltas[i] = dare_ties(task_tensors, valid_weights, density, majority_sign_method)
        elif combination_type == "magnitude_prune":
            lora_deltas[i] = magnitude_prune(task_tensors, valid_weights, density)
        else:
            raise ValueError("Invalid combination type")
    lora_deltas = [delta.to(dtype) for delta in lora_deltas]
    return lora_deltas
def delete_adapter(self, adapter_name: str) -> None:
    """
    Deletes an existing adapter.

    The adapter's config is removed from `self.peft_config` and `delete_adapter` is called on every
    `LoraLayer` in the model. `self.active_adapter` is then set to a copy of the active adapters of the first
    LoRA layer encountered, or to an empty list if there is none.

    Args:
        adapter_name (str): Name of the adapter to be deleted.

    Raises:
        ValueError: If no adapter with that name exists.
    """
    if adapter_name not in list(self.peft_config.keys()):
        raise ValueError(f"Adapter {adapter_name} does not exist")
    del self.peft_config[adapter_name]

    # snapshot of all module names that do not contain the adapter prefix
    module_keys = [module_key for module_key, _ in self.model.named_modules() if self.prefix not in module_key]
    remaining_active = None
    for module_key in module_keys:
        _parent, target, _target_name = _get_submodules(self.model, module_key)
        if not isinstance(target, LoraLayer):
            continue
        target.delete_adapter(adapter_name)
        if remaining_active is None:
            remaining_active = target.active_adapters[:]

    self.active_adapter = remaining_active or []
def merge_and_unload(
    self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
) -> torch.nn.Module:
    r"""
    This method merges the LoRA layers into the base model. This is needed if someone wants to use the base model
    as a standalone model.

    Args:
        progressbar (`bool`):
            whether to show a progressbar indicating the unload and merge process
        safe_merge (`bool`):
            whether to activate the safe merging check to check if there is any potential NaN in the adapter
            weights
        adapter_names (`List[str]`, *optional*):
            The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
            to `None`.

    Returns:
        `torch.nn.Module`: The value returned by `_unload_and_optionally_merge`, called with its default
        `merge=True`.

    Example:

    ```py
    >>> from transformers import AutoModelForCausalLM
    >>> from peft import PeftModel

    >>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b")
    >>> peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample"
    >>> model = PeftModel.from_pretrained(base_model, peft_model_id)
    >>> merged_model = model.merge_and_unload()
    ```
    """
    return self._unload_and_optionally_merge(
        progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
    )
def unload(self) -> torch.nn.Module:
    """
    Gets back the base model by removing all the lora modules without merging. This gives back the original base
    model.

    Returns:
        `torch.nn.Module`: The value returned by `_unload_and_optionally_merge(merge=False)`.
    """
    return self._unload_and_optionally_merge(merge=False)
def subtract_pissa_init(
    self, output_state_dict: dict[str, torch.Tensor], adapter_name: str = "pissa_init", kwargs=None
):
    """
    This function can calculate the updates of the PiSSA by comparing the parameters of the PiSSA adapter in
    `output_state_dict` with the initial values of PiSSA in `adapter_name`, thus converting PiSSA to LoRA.

    Args:
        output_state_dict (`dict[str, torch.Tensor]`):
            State dict holding the trained adapter weights. Only entries whose name contains `"lora_A"` or
            `"lora_B"` are converted.
        adapter_name (`str`):
            Name of the adapter that holds the initial PiSSA weights.
        kwargs (`dict`, *optional*):
            Optional mapping; its `"state_dict"` entry, if present, is passed to `get_peft_model_state_dict`.
            `None` is treated like an empty mapping.

    Returns:
        `dict[str, torch.Tensor]`: For every converted name, the trained tensor concatenated with the initial
        tensor (`lora_A` along dim 0) or with the negated initial tensor (`lora_B` along dim 1).
    """
    # `kwargs` defaults to None, so normalise it before it is queried below.
    kwargs = {} if kwargs is None else kwargs

    # Warn once if any parameter of the model has a non-floating-point dtype.
    float_dtypes = (torch.float32, torch.float16, torch.bfloat16)
    if any(param.data.dtype not in float_dtypes for param in self.model.parameters()):
        warnings.warn(
            r"Note that Quant(W_res) + AB != Quant(W) + \Delta(AB); "
            "the converted LoRA, when combined with W or Quant(W), may introduce a certain gap in the fine-tuned model. "
            "Therefore, we recommend directly using the Quant(W_res) in conjunction with the PiSSA adapter. "
        )

    pissa_init_state_dict = get_peft_model_state_dict(
        self,
        state_dict=kwargs.get("state_dict", None),
        adapter_name=adapter_name,
    )
    tensors_lora = {}
    for name, trained in output_state_dict.items():
        ## W = W^{res} + A_0 \times B_0,
        ## W + \Delta W = W^{res} + A \times B,
        ## \Delta W = A \times B - A_0 \times B_0 = [A | A_0] \times [B | -B_0]^T = A'B'.
        if "lora_A" in name:
            # the initial state dict is keyed by the name with its first dot-separated component removed
            init_key = ".".join(name.split(".")[1:])
            tensors_lora[name] = torch.cat([trained, pissa_init_state_dict[init_key]], dim=0)
        elif "lora_B" in name:
            init_key = ".".join(name.split(".")[1:])
            tensors_lora[name] = torch.cat([trained, -pissa_init_state_dict[init_key]], dim=1)

    return tensors_lora
|