from typing import NamedTuple

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """
    Implementation of:
        - LoRA: https://arxiv.org/abs/2106.09685

    Notes:
        - Freezing is handled at network level, not layer level.
        - Scaling factor controls relative importance of LoRA skip
          connection versus original frozen weight. General guidance is
          to keep it to 2.0 and sweep over learning rate when changing
          the rank.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int,
        scaling: float,
        dropout: float,
        bias: bool = False,
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        assert not bias
        self.bias = bias
        self.rank = rank
        self.scaling = scaling

        self.dropout = nn.Dropout(p=dropout)

        # low-rank down-projection A (in_features -> rank) and
        # up-projection B (rank -> out_features)
        self.lora_A = nn.Linear(
            self.in_features,
            self.rank,
            bias=self.bias,
        )
        self.lora_B = nn.Linear(
            self.rank,
            self.out_features,
            bias=self.bias,
        )
        # original pretrained weight; kept frozen by the surrounding network
        self.frozen_W = nn.Linear(self.in_features, self.out_features, bias=self.bias)

        # make sure no LoRA weights are marked as "missing" in load_state_dict
        def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple):
            # empty missing keys in place
            incompatible_keys.missing_keys[:] = []  # type: ignore

        self.register_load_state_dict_post_hook(ignore_missing_keys)

    def merge_weight(self):
        """Return the merged weight W + scaling * (B @ A) as a single tensor."""
        with torch.no_grad():
            down_weight = self.lora_A.weight
            up_weight = self.lora_B.weight

            weight = up_weight.mm(down_weight) * self.scaling
            weight += self.frozen_W.weight

        return weight

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        key_name = prefix + "weight"

        # full checkpoint
        if key_name in state_dict:
            w_ref = state_dict[key_name]

            # load frozen weights
            self.frozen_W.load_state_dict({"weight": w_ref}, assign=True)

    def forward(self, x: torch.Tensor):
        # frozen path plus the scaled low-rank update
        lora = self.lora_B(self.lora_A(self.dropout(x)))
        return self.frozen_W(x) + lora * self.scaling

    def __repr__(self) -> str:
        return "{}Linear(in_features={}, out_features={}, r={}, dropout={})".format(
            "LoRA", self.in_features, self.out_features, self.rank, self.dropout.p
        )
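

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the layer sizes,
    # rank, scaling, and dropout below are illustrative assumptions. It freezes
    # everything except the LoRA matrices (mirroring the "network level"
    # freezing mentioned in the docstring) and checks that merge_weight()
    # reproduces the eval-mode forward pass.
    torch.manual_seed(0)

    layer = LoRALinear(
        in_features=64,
        out_features=128,
        rank=8,
        scaling=2.0,
        dropout=0.0,
    )

    # freeze the pretrained weight; leave only lora_A / lora_B trainable
    for name, param in layer.named_parameters():
        param.requires_grad = "lora_A" in name or "lora_B" in name

    x = torch.randn(4, 64)
    layer.eval()  # disable dropout so the merged weight is exactly equivalent
    y = layer(x)

    # fold W + scaling * (B @ A) into a plain nn.Linear and compare outputs
    merged = nn.Linear(64, 128, bias=False)
    with torch.no_grad():
        merged.weight.copy_(layer.merge_weight())

    assert torch.allclose(y, merged(x), atol=1e-5)
    print(layer)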