import math
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F

import lit_llama.model as llama

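# LoRA (Hu et al., 2021) keeps the pretrained weight W frozen and learns a low-rank
# update delta_W = B @ A, with A of shape (r, in_features) and B of shape
# (out_features, r), applied as W_eff = W + (alpha / r) * B @ A. Only A and B are
# trained, which keeps the number of trainable parameters small. The classes below
# follow that scheme; `MergedLinear` additionally lets the update be enabled for only
# some slices of a fused output projection (see the `enable_lora` argument).
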
class LoRALayer:
    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout applied to the input of the LoRA branch
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Whether the LoRA update is currently folded into the frozen weight
        self.merged = False
        self.merge_weights = merge_weights

class MergedLinear(nn.Linear, LoRALayer):
    """An ``nn.Linear`` layer with a low-rank LoRA update applied to a subset of its output groups.

    ``enable_lora`` splits the output dimension into ``len(enable_lora)`` equally sized
    groups and selects which of them receive a LoRA update.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        enable_lora: List[bool] = [False],
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out

        # Actual trainable parameters: one (r x in_features) block of A and one
        # (group_size x r) block of B per enabled output group
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
            )  # used as the kernel of a grouped conv1d with groups=sum(enable_lora)
            self.scaling = self.lora_alpha / self.r
            # Freeze the pretrained weight matrix; only lora_A and lora_B are trained
            self.weight.requires_grad = False
            # Boolean mask over the output dimension marking which indices receive the LoRA update
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        # nn.Linear.__init__ also calls this before lora_A exists; the hasattr guard covers that case
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # Initialize A the same way as the default for nn.Linear and B to zero,
            # so that the LoRA update B @ A starts out as zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def zero_pad(self, x):
        # Scatter the LoRA output, which only covers the enabled groups, into a zero
        # tensor over the full output dimension at the positions marked by lora_ind
        x = x.transpose(0, 1)
        result = x.new_zeros((*x.shape[:-1], self.out_features))
        result = result.view(-1, self.out_features)
        result[:, self.lora_ind] = x.reshape(
            -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
        )
        return result.view((*x.shape[:-1], self.out_features)).transpose(0, 1)

    def train(self, mode: bool = True):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)

        # train(True)  -> unmerge the LoRA update from the weight, unless already unmerged
        # train(False) -> merge the LoRA update into the weight, unless already merged
        should = self.merged if mode else not self.merged

        if self.merge_weights and should:
            if self.r > 0 and any(self.enable_lora):
                # delta_W = B @ A, computed block-wise with a grouped 1D convolution
                delta_w = F.conv1d(
                    self.lora_A.data.unsqueeze(0),
                    self.lora_B.data.unsqueeze(-1),
                    groups=sum(self.enable_lora)
                ).squeeze(0)
                # -1: W = W - delta_W (unmerge), +1: W = W + delta_W (merge)
                sign = -1 if mode else 1
                self.weight.data += sign * self.zero_pad(T(delta_w * self.scaling))
            self.merged = not mode

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        if self.merged:
            # The LoRA update is already folded into self.weight
            return F.linear(x, T(self.weight), bias=self.bias)
        else:
            result = F.linear(x, T(self.weight), bias=self.bias)
            if self.r > 0:
                # x @ A^T, then the grouped conv1d applies B separately to each enabled group
                after_A = F.linear(self.lora_dropout(x), self.lora_A)
                after_B = F.conv1d(
                    after_A.transpose(-2, -1),
                    self.lora_B.unsqueeze(-1),
                    groups=sum(self.enable_lora)
                ).transpose(-2, -1)
                result += self.zero_pad(after_B) * self.scaling
            return result

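# Worked example of the shapes involved (illustrative numbers, not taken from the code above):
# with in_features=4096, out_features=3*4096=12288, enable_lora=[True, False, True], r=8:
#   lora_A has shape (r * 2, 4096) = (16, 4096)
#   lora_B has shape (12288 // 3 * 2, 8) = (8192, 8)
#   the grouped conv1d multiplies each 4096-row block of B with its own 8-row block of A,
#   giving an (8192, 4096) delta, and zero_pad() scatters those rows into the first and
#   last thirds of the full (12288, 4096) weight, leaving the disabled middle group untouched.
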
def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
    """Freeze all parameters of ``model`` except the LoRA matrices (and, depending on ``bias``, biases)."""
    for n, p in model.named_parameters():
        if 'lora_' not in n:
            p.requires_grad = False
    if bias == 'none':
        return
    elif bias == 'all':
        for n, p in model.named_parameters():
            if 'bias' in n:
                p.requires_grad = True
    elif bias == 'lora_only':
        for m in model.modules():
            if isinstance(m, LoRALayer) and \
                    hasattr(m, 'bias') and \
                    m.bias is not None:
                m.bias.requires_grad = True
    else:
        raise NotImplementedError

def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
    """Return a state dict containing only the LoRA parameters (and, depending on ``bias``, biases)."""
    my_state_dict = model.state_dict()
    if bias == 'none':
        return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
    elif bias == 'all':
        return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
    elif bias == 'lora_only':
        to_return = {}
        for k in my_state_dict:
            if 'lora_' in k:
                to_return[k] = my_state_dict[k]
                bias_name = k.split('lora_')[0] + 'bias'
                if bias_name in my_state_dict:
                    to_return[bias_name] = my_state_dict[bias_name]
        return to_return
    else:
        raise NotImplementedError

@dataclass
class LoRAConfig:
    """LoRA hyperparameters: rank ``r``, scaling factor ``alpha``, and dropout probability."""
    r: float = 0.0
    alpha: float = 1.0
    dropout: float = 0.0

class CausalSelfAttention(llama.CausalSelfAttention):
    lora_config = None

    def __init__(self, config: llama.LLaMAConfig) -> None:
        # Skip the parent __init__ entirely and rebuild the layer so that the combined
        # query/key/value projection uses MergedLinear instead of nn.Linear
        nn.Module.__init__(self)
        assert config.n_embd % config.n_head == 0

        # Combined query, key and value projection for all heads, computed in one matmul
        self.c_attn = MergedLinear(
            in_features=config.n_embd,
            out_features=3 * config.n_embd,
            r=self.lora_config.r,
            lora_alpha=self.lora_config.alpha,
            lora_dropout=self.lora_config.dropout,
            enable_lora=[True, False, True],  # apply LoRA to the query and value projections, not the key
            fan_in_fan_out=False,
            merge_weights=True,
            bias=False)
        # Output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.block_size = config.block_size
        self.rope_cache = None

@contextmanager
def lora(r, alpha, dropout, enabled: bool = True):
    """A context manager under which you can instantiate the model with LoRA."""
    if not enabled:
        yield
        return

    CausalSelfAttention.lora_config = LoRAConfig(r=r, alpha=alpha, dropout=dropout)
    # Swap the attention implementation so that models built inside the context use the LoRA variant
    causal_self_attention = llama.CausalSelfAttention
    llama.CausalSelfAttention = CausalSelfAttention
    yield
    # Restore the original class and drop the config when leaving the context
    llama.CausalSelfAttention = causal_self_attention
    CausalSelfAttention.lora_config = None
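
# Example of the intended usage (a sketch; assumes `lit_llama.model` exposes the `LLaMA`
# module and `LLaMAConfig`, and that fine-tuning and checkpointing happen elsewhere):
#
#     with lora(r=8, alpha=16, dropout=0.05, enabled=True):
#         model = llama.LLaMA(llama.LLaMAConfig())  # built with the LoRA-aware attention
#     mark_only_lora_as_trainable(model)  # freeze everything except lora_A / lora_B
#     # ... fine-tune the model ...
#     torch.save(lora_state_dict(model), "lora_weights.pth")  # persist only the LoRA tensors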