from typing import NamedTuple

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """
    Implementation of:
        - LoRA: https://arxiv.org/abs/2106.09685

    Notes:
        - Freezing is handled at network level, not layer level.
        - Scaling factor controls relative importance of LoRA skip
          connection versus original frozen weight. General guidance is
          to keep it to 2.0 and sweep over learning rate when changing
          the rank.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int,
        scaling: float,
        dropout: float,
        bias: bool = False,
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        assert not bias
        self.bias = bias
        self.rank = rank
        self.scaling = scaling

        self.dropout = nn.Dropout(p=dropout)

        # Low-rank adapters: lora_A projects down to `rank`, lora_B projects back up.
        self.lora_A = nn.Linear(
            self.in_features,
            self.rank,
            bias=self.bias,
        )
        self.lora_B = nn.Linear(
            self.rank,
            self.out_features,
            bias=self.bias,
        )

        # Original projection, kept frozen at the network level; its weight is
        # filled from the checkpoint in _load_from_state_dict below.
        self.frozen_W = nn.Linear(self.in_features, self.out_features, bias=self.bias)

        def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple):
            # Clear (in place) every key reported as missing by load_state_dict,
            # since the checkpoint only ships the plain "weight" tensor and not
            # the lora_A / lora_B / frozen_W parameters.
            incompatible_keys.missing_keys[:] = []

        self.register_load_state_dict_post_hook(ignore_missing_keys)

    def merge_weight(self):
        # Fold the adapter into a single matrix: frozen_W.weight + scaling * B @ A.
        with torch.no_grad():
            down_weight = self.lora_A.weight
            up_weight = self.lora_B.weight

            weight = up_weight.mm(down_weight) * self.scaling

            weight += self.frozen_W.weight
        return weight

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        key_name = prefix + "weight"

        if key_name in state_dict:
            w_ref = state_dict[key_name]

            # Route the checkpoint's plain Linear weight into frozen_W.
            self.frozen_W.load_state_dict({"weight": w_ref}, assign=True)

    def forward(self, x: torch.Tensor):
        # y = frozen_W(x) + scaling * lora_B(lora_A(dropout(x)))
        lora = self.lora_B(self.lora_A(self.dropout(x)))
        return self.frozen_W(x) + lora * self.scaling

    def __repr__(self) -> str:
        return "{}Linear(in_features={}, out_features={}, r={}, dropout={})".format(
            "LoRA", self.in_features, self.out_features, self.rank, self.dropout.p
        )
|
|
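# A minimal usage sketch, not part of the original module: the layer sizes,
# rank, and scaling below are illustrative assumptions, and frozen_W keeps its
# default initialization rather than pretrained weights loaded via load_state_dict.
if __name__ == "__main__":
    layer = LoRALinear(
        in_features=4096,
        out_features=4096,
        rank=8,
        scaling=2.0,
        dropout=0.0,
    )

    x = torch.randn(2, 4096)
    y = layer(x)  # frozen_W(x) + scaling * lora_B(lora_A(x))
    print(y.shape)  # torch.Size([2, 4096])

    # merge_weight() folds the adapter into a single matrix that a plain
    # nn.Linear could use at inference time.
    merged = layer.merge_weight()
    print(merged.shape)  # torch.Size([4096, 4096])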