File size: 6,045 Bytes
12001a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
"""Implementation of the paper:
LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
https://arxiv.org/abs/2303.16199
"""
# mypy: ignore-errors
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
import lit_llama.model as llama
from lit_llama.model import build_rope_cache, apply_rope, RMSNorm, MLP
@dataclass
class LLaMAConfig(llama.LLaMAConfig):
    """Base LLaMA configuration extended with the adapter hyper-parameters."""

    # Number of learnable adaption-prompt embeddings held by each adapted
    # attention layer (the row count of its `adapter_wte` embedding table).
    adapter_prompt_length: int = 10
    # Index of the first transformer block that receives an adapter; blocks
    # with a smaller index run plain causal self-attention.
    adapter_start_layer: int = 2
class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with an optional gated adaption prompt.

    A modification of `lit_llama.model.CausalSelfAttention`. Blocks whose index
    is at least `config.adapter_start_layer` own `config.adapter_prompt_length`
    learnable prefix embeddings. Every query additionally attends over the
    keys/values projected from that prefix, and the result is added to the
    regular attention output scaled by a learnable gate that starts at zero.
    """

    def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
        """Create the projections and, for adapted blocks, the adapter weights.

        Args:
            config: model configuration; `n_embd` must be divisible by `n_head`.
            block_idx: index of the transformer block this layer belongs to.
        """
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        if block_idx >= config.adapter_start_layer:
            # adapter embedding layer
            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
            # gate for adaption; zero-initialised so the adapter branch adds
            # nothing to the output until the gate has been trained
            self.gating_factor = torch.nn.Parameter(torch.zeros(1))

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.block_size = config.block_size
        self.block_idx = block_idx
        self.adapter_prompt_length = config.adapter_prompt_length
        self.adapter_start_layer = config.adapter_start_layer
        # Built lazily on the first forward call (needs the input dtype/device).
        self.rope_cache = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply (adapter-augmented) causal self-attention.

        Args:
            x: input of shape (B, T, C) with C == `n_embd`.

        Returns:
            Tensor of shape (B, T, C) after the output projection.
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        head_size = C // self.n_head
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)

        # `rope_cache` is a plain attribute, not a registered buffer, so
        # `Module.to()` does not move it. Rebuild it whenever it is missing or
        # lives on a different device than the current input, instead of
        # reusing a stale cache.
        if self.rope_cache is None or self.rope_cache.device != x.device:
            # cache for future forward calls
            self.rope_cache = build_rope_cache(
                seq_len=self.block_size,
                n_elem=self.n_embd // self.n_head,
                dtype=x.dtype,
                device=x.device,
            )

        q = apply_rope(q, self.rope_cache)
        k = apply_rope(k, self.rope_cache)

        # Causal self-attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T).
        # The fused call below replaces the explicit formulation
        # softmax(mask(q @ k^T / sqrt(hs))) @ v.
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)

        if self.block_idx >= self.adapter_start_layer:
            prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd)

            aT = prefix.size(1)
            # Only the key and value projections of the prefix are used; the
            # queries still come from the input sequence.
            _, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2)
            ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)
            av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)

            # All-True mask: every position may attend to every prefix token.
            amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device)
            ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False)
            y = y + self.gating_factor * ay

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.c_proj(y)

        return y
class Block(nn.Module):
    """Transformer block: pre-norm attention and pre-norm MLP, each with a residual.

    Mirrors `lit_llama.model.Block`, except that the attention layer is the
    adapter-aware `CausalSelfAttention`, which is given the block index.
    """

    def __init__(self, config: LLaMAConfig, block_idx: int) -> None:
        super().__init__()
        width = config.n_embd
        # Submodules are created in the same order as before so parameter
        # registration (and random initialisation order) is unchanged.
        self.rms_1 = RMSNorm(width)
        self.attn = CausalSelfAttention(config, block_idx)
        self.rms_2 = RMSNorm(width)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention sub-layer: normalise, attend, add the residual.
        attn_out = self.attn(self.rms_1(x))
        hidden = x + attn_out
        # Feed-forward sub-layer: normalise, transform, add the residual.
        mlp_out = self.mlp(self.rms_2(hidden))
        return hidden + mlp_out
class LLaMA(llama.LLaMA):
    """LLaMA model assembled from adapter-aware blocks.

    Same layout as `lit_llama.model.LLaMA`; the difference is that every
    `Block` is constructed with its layer index, which it passes down to the
    attention layer.
    """

    def __init__(self, config: LLaMAConfig) -> None:
        # Only the `nn.Module` base is initialised; the parent class
        # initialiser is not called and all submodules are built here.
        nn.Module.__init__(self)
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
        blocks = nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)])
        final_norm = RMSNorm(config.n_embd)
        self.transformer = nn.ModuleDict(
            {
                "wte": token_embedding,
                "h": blocks,
                "ln_f": final_norm,
            }
        )

    @classmethod
    def from_name(cls, name: str):
        """Build a model from the configuration registered under `name`."""
        config = LLaMAConfig.from_name(name)
        return cls(config)
def mark_only_adapter_as_trainable(model: LLaMA) -> None:
    """Enable gradients for adapter parameters only.

    A parameter stays trainable when its name contains `adapter_wte` or
    `gating_factor`; `requires_grad` is switched off for every other parameter.
    """
    adapter_markers = ("adapter_wte", "gating_factor")
    for param_name, parameter in model.named_parameters():
        parameter.requires_grad = any(marker in param_name for marker in adapter_markers)
def adapter_state_from_state_dict(state_dict: dict) -> dict:
    """Filter a model state dict down to the adapter weights, e.g. for saving.

    Keeps the entries whose key contains `adapter_wte` or `gating_factor`, in
    their original order; values are passed through unchanged.
    """
    adapter_markers = ("adapter_wte", "gating_factor")
    adapter_state = {}
    for key, value in state_dict.items():
        if any(marker in key for marker in adapter_markers):
            adapter_state[key] = value
    return adapter_state
|