chansung committed on
Commit f5e2f31
1 Parent(s): 91d916b

Create model.py

Files changed (1)
  1. llama/model.py +238 -0
llama/model.py ADDED
@@ -0,0 +1,238 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import Optional, Tuple
from dataclasses import dataclass
import math

import torch
from torch import nn
import torch.nn.functional as F

import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.layers import (
    ParallelEmbedding,
    RowParallelLinear,
    ColumnParallelLinear,
)

@dataclass
class ModelArgs:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 1024

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

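# The three helpers below implement rotary position embeddings (RoPE):
# precompute_freqs_cis builds the complex rotation factors for every position,
# reshape_for_broadcast aligns them with the (batch, seq, heads, head_dim/2) layout,
# and apply_rotary_emb rotates the query/key vectors in the complex plane.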
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

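# Multi-head attention with model-parallel projections (fairscale Column/RowParallelLinear)
# and a per-process key/value cache of shape (max_batch_size, max_seq_len, n_local_heads,
# head_dim), so incremental decoding only attends over positions written so far.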
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
        output = output.transpose(
            1, 2
        ).contiguous().view(bsz, seqlen, -1)

        return self.wo(output)

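# SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x)); the hidden size is 2/3 of the
# incoming hidden_dim (4 * dim from TransformerBlock), rounded up to a multiple of multiple_of.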
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

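# Pre-norm residual block: x + Attention(RMSNorm(x)), then h + FeedForward(RMSNorm(h)).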
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out

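# Full model: token embeddings -> n_layers TransformerBlocks -> RMSNorm -> output projection.
# forward() runs incrementally: it consumes seqlen new tokens starting at start_pos and
# returns logits for the last position only.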
class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h[:, -1, :])  # only compute last logits
        return output.float()
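
For reference, a minimal single-process usage sketch of the module added here. It is not part of this commit: it assumes the repository root is on PYTHONPATH, a single GPU is available, and fairscale's model-parallel group is initialized with world size 1; the vocab_size, rendezvous address, and random prompt tokens are illustrative placeholders (in practice ModelArgs comes from a checkpoint's params and vocab_size from the tokenizer).

# Minimal sketch, not part of this commit; values below are illustrative placeholders.
import torch
import torch.distributed as dist
import fairscale.nn.model_parallel.initialize as fs_init

from llama.model import ModelArgs, Transformer

dist.init_process_group(
    "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)
fs_init.initialize_model_parallel(1)  # model-parallel size 1 (single GPU)

args = ModelArgs(vocab_size=32000)  # vocab_size is normally taken from the tokenizer
model = Transformer(args).cuda()    # Attention.__init__ already places KV caches on CUDA

tokens = torch.randint(0, args.vocab_size, (1, 8), device="cuda")
logits = model(tokens, start_pos=0)       # shape (1, vocab_size): logits for the last position only
next_token = torch.argmax(logits, dim=-1)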