Yunus Serhat Bıçakçı committed
Commit 927af54 · 1 Parent(s): 39a3c07
llama/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .generation import LLaMA
+ from .model import ModelArgs, Transformer, VisionModel
+ from .tokenizer import Tokenizer
llama/generation.py ADDED
@@ -0,0 +1,82 @@
+ from typing import List
+
+ import torch
+
+ from llama.tokenizer import Tokenizer
+ from llama.model import Transformer
+
+
+ class LLaMA:
+     def __init__(self, model: Transformer, tokenizer: Tokenizer, vision_model = None):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.vision_model = vision_model
+
+     def generate(
+         self,
+         prompts: List[str],
+         imgs = None,
+         max_gen_len: int = 512,
+         temperature: float = 0.8,
+         top_p: float = 0.95,
+     ) -> List[str]:
+         bsz = len(prompts)
+         params = self.model.params
+         assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
+
+         mode = 'instruct'
+         vision_tokens = None
+         if imgs is not None and self.vision_model is not None:
+             vision_tokens = self.vision_model(imgs)
+             mode = 'caption'
+
+         prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
+
+         min_prompt_size = min([len(t) for t in prompt_tokens])
+         max_prompt_size = max([len(t) for t in prompt_tokens])
+
+         total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
+
+         tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
+         for k, t in enumerate(prompt_tokens):
+             tokens[k, : len(t)] = torch.tensor(t).long()
+         input_text_mask = tokens != self.tokenizer.pad_id
+         start_pos = min_prompt_size
+         prev_pos = 0
+         for cur_pos in range(start_pos, total_len):
+             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos, vision_tokens, mode)
+             if temperature > 0:
+                 probs = torch.softmax(logits / temperature, dim=-1)
+                 next_token = sample_top_p(probs, top_p)
+             else:
+                 next_token = torch.argmax(logits, dim=-1)
+             next_token = next_token.reshape(-1)
+             # only replace token if prompt has already been generated
+             next_token = torch.where(
+                 input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
+             )
+             tokens[:, cur_pos] = next_token
+             prev_pos = cur_pos
+
+         decoded = []
+         for i, t in enumerate(tokens.tolist()):
+             # cut to max gen len
+             t = t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len]
+             # cut to eos tok if any
+             try:
+                 t = t[: t.index(self.tokenizer.eos_id)]
+             except ValueError:
+                 pass
+             decoded.append(self.tokenizer.decode(t))
+         return decoded
+
+
+ def sample_top_p(probs, p):
+     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+     probs_sum = torch.cumsum(probs_sort, dim=-1)
+     mask = probs_sum - probs_sort > p
+     probs_sort[mask] = 0.0
+     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+     next_token = torch.multinomial(probs_sort, num_samples=1)
+     next_token = torch.gather(probs_idx, -1, next_token)
+     return next_token
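For reference, a small usage sketch of the nucleus-sampling helper above (a hypothetical example; it assumes the repository's requirements are installed so that llama.generation imports cleanly):

import torch
from llama.generation import sample_top_p

# Toy next-token distribution over a 5-token vocabulary (batch size 1).
probs = torch.tensor([[0.50, 0.25, 0.15, 0.07, 0.03]])

# With p=0.6 a token is kept only while the cumulative mass of strictly more
# probable tokens is <= 0.6, so just the first two tokens survive; they are
# renormalised and one of them is drawn by torch.multinomial.
next_token = sample_top_p(probs, p=0.6)
assert next_token.item() in (0, 1)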
llama/model.py ADDED
@@ -0,0 +1,424 @@
+ from typing import Optional, Tuple
+ from dataclasses import dataclass
+ import math
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ import clip
+ from timm.models.vision_transformer import Block
+
+ import fairscale.nn.model_parallel.initialize as fs_init
+ from fairscale.nn.model_parallel.layers import (
+     ParallelEmbedding,
+     RowParallelLinear,
+     ColumnParallelLinear,
+ )
+
+
+ @dataclass
+ class ModelArgs:
+     dim: int = 512
+     n_layers: int = 8
+     n_heads: int = 8
+     vocab_size: int = -1  # defined later by tokenizer
+     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
+     norm_eps: float = 1e-5
+
+     max_batch_size: int = 32
+     max_seq_len: int = 2048
+
+     adapter_len: int = 10
+     adapter_layer: int = 30
+
+     cap_adapter_len: int = 10
+     cap_adapter_layer: int = 30
+     cap_vision_model: str = "ViT-L/14"
+     cap_vision_dim: int = 512
+     cap_vision_block: int = 2
+
+
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device)  # type: ignore
+     freqs = torch.outer(t, freqs).float()  # type: ignore
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+     return freqs_cis
+
+
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+     ndim = x.ndim
+     assert 0 <= 1 < ndim
+     assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+     return freqs_cis.view(*shape)
+
+
+ def apply_rotary_emb(
+     xq: torch.Tensor,
+     xk: torch.Tensor,
+     freqs_cis: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+ class Attention(nn.Module):
+     def __init__(self, args: ModelArgs):
+         super().__init__()
+
+         self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
+         self.head_dim = args.dim // args.n_heads
+
+         self.wq = ColumnParallelLinear(
+             args.dim,
+             args.n_heads * self.head_dim,
+             bias=False,
+             gather_output=False,
+             init_method=lambda x: x,
+         )
+         self.wk = ColumnParallelLinear(
+             args.dim,
+             args.n_heads * self.head_dim,
+             bias=False,
+             gather_output=False,
+             init_method=lambda x: x,
+         )
+         self.wv = ColumnParallelLinear(
+             args.dim,
+             args.n_heads * self.head_dim,
+             bias=False,
+             gather_output=False,
+             init_method=lambda x: x,
+         )
+         self.wo = RowParallelLinear(
+             args.n_heads * self.head_dim,
+             args.dim,
+             bias=False,
+             input_is_parallel=True,
+             init_method=lambda x: x,
+         )
+
+         self.cache_k = torch.zeros(
+             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
+         ).cuda()
+         self.cache_v = torch.zeros(
+             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
+         ).cuda()
+         self.gate = torch.nn.Parameter(torch.zeros(1))
+
+         self.cap_gate = torch.nn.Parameter(torch.zeros(1, self.n_local_heads, 1, 1))
+
+     def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor],
+                 adapter=None, mode='instruct'):
+         if mode == 'instruct':
+             return self.forward_instruct(x, start_pos, freqs_cis, mask, adapter)
+         elif mode == 'caption':
+             return self.forward_caption(x, start_pos, freqs_cis, mask, adapter)
+
+     def forward_instruct(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor],
+                          adapter=None):
+         bsz, seqlen, _ = x.shape
+         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+         self.cache_k = self.cache_k.to(xq)
+         self.cache_v = self.cache_v.to(xq)
+
+         self.cache_k[:bsz, start_pos: start_pos + seqlen] = xk
+         self.cache_v[:bsz, start_pos: start_pos + seqlen] = xv
+
+         keys = self.cache_k[:bsz, : start_pos + seqlen]
+         values = self.cache_v[:bsz, : start_pos + seqlen]
+
+         if adapter is not None:
+             adapter_len = adapter.shape[1]
+             adapter_k = self.wk(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
+             adapter_v = self.wv(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
+             adapter_k = adapter_k.transpose(1, 2)
+             adapter_v = adapter_v.transpose(1, 2)
+         xq = xq.transpose(1, 2)
+         keys = keys.transpose(1, 2)
+         values = values.transpose(1, 2)
+         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+         if mask is not None:
+             scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
+         scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+         output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
+         if adapter is not None:
+             adapter_scores = torch.matmul(xq, adapter_k.transpose(2, 3)) / math.sqrt(self.head_dim)
+             adapter_scores = self.gate * F.softmax(adapter_scores.float(), dim=-1).type_as(xq)
+             output = output + torch.matmul(adapter_scores, adapter_v)
+         output = output.transpose(
+             1, 2
+         ).contiguous().view(bsz, seqlen, -1)
+
+         return self.wo(output)
+
+     def forward_caption(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor],
+                         adapter=None):
+         bsz, seqlen, _ = x.shape
+         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+         self.cache_k = self.cache_k.to(xq)
+         self.cache_v = self.cache_v.to(xq)
+
+         self.cache_k[:bsz, start_pos: start_pos + seqlen] = xk
+         self.cache_v[:bsz, start_pos: start_pos + seqlen] = xv
+
+         keys = self.cache_k[:bsz, : start_pos + seqlen]
+         values = self.cache_v[:bsz, : start_pos + seqlen]
+
+         if adapter is not None:
+             adapter_len = adapter.shape[1]
+             adapter_k = self.wk(adapter).view(bsz, adapter_len, self.n_local_heads, self.head_dim)
+             adapter_v = self.wv(adapter).view(bsz, adapter_len, self.n_local_heads, self.head_dim)
+             adapter_k = adapter_k.transpose(1, 2)
+             adapter_v = adapter_v.transpose(1, 2)
+         xq = xq.transpose(1, 2)
+         keys = keys.transpose(1, 2)
+         values = values.transpose(1, 2)
+         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+         if mask is not None:
+             scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
+         scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+         output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
+         if adapter is not None:
+             adapter_scores = torch.matmul(xq, adapter_k.transpose(2, 3)) / math.sqrt(self.head_dim)
+             adapter_scores = self.cap_gate.tanh() * F.softmax(adapter_scores.float(), dim=-1).type_as(xq)
+
+             output = output + torch.matmul(adapter_scores, adapter_v)
+         output = output.transpose(
+             1, 2
+         ).contiguous().view(bsz, seqlen, -1)
+
+         return self.wo(output)
+
+
+ class FeedForward(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         hidden_dim: int,
+         multiple_of: int,
+     ):
+         super().__init__()
+         hidden_dim = int(2 * hidden_dim / 3)
+         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+         self.w1 = ColumnParallelLinear(
+             dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+         )
+         self.w2 = RowParallelLinear(
+             hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
+         )
+         self.w3 = ColumnParallelLinear(
+             dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+         )
+
+     def forward(self, x):
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, layer_id: int, args: ModelArgs):
+         super().__init__()
+         self.n_heads = args.n_heads
+         self.dim = args.dim
+         self.head_dim = args.dim // args.n_heads
+         self.attention = Attention(args)
+         self.feed_forward = FeedForward(
+             dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
+         )
+         self.layer_id = layer_id
+         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+     def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor],
+                 adapter=None, mode='instruct'):
+         h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, adapter, mode=mode)
+         out = h + self.feed_forward.forward(self.ffn_norm(h))
+         return out
+
+
+ class Transformer(nn.Module):
+     def __init__(self, params: ModelArgs):
+         super().__init__()
+         self.params = params
+         self.vocab_size = params.vocab_size
+         self.n_layers = params.n_layers
+
+         self.tok_embeddings = ParallelEmbedding(
+             params.vocab_size, params.dim, init_method=lambda x: x
+         )
+
+         self.layers = torch.nn.ModuleList()
+         for layer_id in range(params.n_layers):
+             self.layers.append(TransformerBlock(layer_id, params))
+
+         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
+         self.output = ColumnParallelLinear(
+             params.dim, params.vocab_size, bias=False, init_method=lambda x: x
+         )
+
+         self.freqs_cis = precompute_freqs_cis(
+             self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
+         )
+
+         # Note: this is only a preview of multimodal LLaMA-Adapter
+         # and requires more efforts to decouple LLaMA-Adapter from LLaMA.
+         # instruct model
+         self.adapter_query = nn.Embedding(params.adapter_len * params.adapter_layer, params.dim)
+         self.adapter_len = params.adapter_len
+         self.adapter_layer = params.adapter_layer
+
+         # caption model
+         self.cap_adapter_query = nn.Embedding(params.cap_adapter_len * params.cap_adapter_layer, params.dim)
+         self.cap_adapter_len = params.cap_adapter_len
+         self.cap_adapter_layer = params.cap_adapter_layer
+
+     @torch.inference_mode()
+     def forward(self, tokens: torch.Tensor, start_pos: int, visual_tokens: torch.Tensor = None, mode: str = 'instruct'):
+         if mode == 'instruct':
+             return self.forward_instruct(tokens, start_pos, mode)
+         elif mode == 'caption':
+             return self.forward_caption(tokens, start_pos, visual_tokens, mode)
+
+     def forward_instruct(self, tokens: torch.Tensor, start_pos: int, mode=None):
+         _bsz, seqlen = tokens.shape
+         h = self.tok_embeddings(tokens)
+         self.freqs_cis = self.freqs_cis.to(h.device)
+         freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]
+         adapter = self.adapter_query.weight.reshape(self.params.adapter_layer, self.params.adapter_len,
+                                                     self.params.dim).unsqueeze(1)
+         mask = None
+         if seqlen > 1:
+             mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
+             mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
+
+         for layer in self.layers[: -1 * self.params.adapter_layer]:
+             h = layer(h, start_pos, freqs_cis, mask)
+         layer_index = 0
+         for layer in self.layers[-1 * self.params.adapter_layer:]:
+             h = layer(h, start_pos, freqs_cis, mask, adapter[layer_index], mode=mode)
+             layer_index = layer_index + 1
+         h = self.norm(h)
+         output = self.output(h[:, -1, :])  # only compute last logits
+         return output.float()
+
+     def forward_caption(self, tokens: torch.Tensor, start_pos: int, visual_tokens: torch.Tensor = None, mode=None):
+         _bsz, seqlen = tokens.shape
+         h = self.tok_embeddings(tokens)
+         self.freqs_cis = self.freqs_cis.to(h.device)
+         freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]
+         adapter = self.cap_adapter_query.weight.reshape(self.params.cap_adapter_layer, self.params.cap_adapter_len,
+                                                         self.params.dim).unsqueeze(1)
+         mask = None
+         if seqlen > 1:
+             mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
+             mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
+
+         for layer in self.layers[: -1 * self.params.cap_adapter_layer]:
+             h = layer(h, start_pos, freqs_cis, mask)
+         layer_index = 0
+         for layer in self.layers[-1 * self.params.cap_adapter_layer:]:
+             adapter_per_layer = adapter[layer_index]
+             if visual_tokens is not None:
+                 adapter_per_layer = adapter_per_layer + visual_tokens
+             h = layer(h, start_pos, freqs_cis, mask, adapter_per_layer, mode=mode)
+             layer_index = layer_index + 1
+         h = self.norm(h)
+         output = self.output(h[:, -1, :])  # only compute last logits
+         return output.float()
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, params: ModelArgs):
+         super().__init__()
+
+         self.params = params
+
+         self.clip, self.clip_transform = clip.load(params.cap_vision_model)
+         self.clip.float()
+         for param in self.clip.parameters():
+             param.requires_grad = False
+
+         self.clip_proj = nn.Linear(self.clip.visual.output_dim, params.cap_vision_dim)
+         self.clip_proj_norm = nn.LayerNorm(params.cap_vision_dim)
+
+         self.visual_query = nn.Embedding(params.cap_adapter_len, params.cap_vision_dim)
+
+         self.visual_blocks = nn.ModuleList([
+             Block(params.cap_vision_dim, 16, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm)
+             for i in range(params.cap_vision_block)])
+
+         self.visual_proj = nn.Linear(params.cap_vision_dim, params.dim)
+         self.visual_proj_norm = nn.LayerNorm(params.dim)
+
+     def clip_encode_image(self, x):
+         x = self.clip.visual.conv1(x)  # shape = [*, width, grid, grid]
+         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+         x = torch.cat([self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1],
+                                                                                   dtype=x.dtype, device=x.device), x],
+                       dim=1)  # shape = [*, grid ** 2 + 1, width]
+         x = x + self.clip.visual.positional_embedding.to(x.dtype)
+         x = self.clip.visual.ln_pre(x)
+
+         x = x.permute(1, 0, 2)  # NLD -> LND
+         x = self.clip.visual.transformer(x)
+         x = x.permute(1, 0, 2)  # LND -> NLD
+
+         x = self.clip.visual.ln_post(x[:, :, :])
+
+         if self.clip.visual.proj is not None:
+             x = x @ self.clip.visual.proj
+
+         return x
+
+     def forward(self, imgs):
+         x = [self.clip_transform(img) for img in imgs]
+         x = torch.stack(x, dim=0).to(self.visual_query.weight.device)
+         _bsz = x.shape[0]
+
+         visual_feats = self.clip_encode_image(x).half()
+         visual_feats = self.clip_proj_norm(self.clip_proj(visual_feats))
+         visual_query = self.visual_query.weight.unsqueeze(0).repeat(_bsz, 1, 1)
+         visual_query = torch.cat([visual_query, visual_feats], dim=1)
+         for block in self.visual_blocks:
+             visual_query = block(visual_query)
+         visual_query = visual_query[:, :self.params.cap_adapter_len, :]
+         visual_query = self.visual_proj(visual_query)
+         visual_query = self.visual_proj_norm(visual_query)
+
+         return visual_query
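As a quick shape check, a hypothetical sketch (assuming torch and the other dependencies imported above are installed) of how the rotary-embedding helpers in this file fit together: the precomputed complex table has shape (seqlen, head_dim // 2) and is applied to query/key tensors of shape (bsz, seqlen, n_heads, head_dim).

import torch
from llama.model import precompute_freqs_cis, apply_rotary_emb

bsz, seqlen, n_heads, head_dim = 2, 16, 8, 64

# One complex rotation per position and per pair of channels: (16, 32) complex64.
freqs_cis = precompute_freqs_cis(head_dim, seqlen)

xq = torch.randn(bsz, seqlen, n_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_heads, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

# The rotation changes values but leaves the tensor shapes untouched.
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape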
llama/tokenizer.py ADDED
@@ -0,0 +1,37 @@
+ from sentencepiece import SentencePieceProcessor
+ from logging import getLogger
+ from typing import List
+ import os
+
+
+ logger = getLogger()
+
+
+ class Tokenizer:
+     def __init__(self, model_path: str):
+         # reload tokenizer
+         assert os.path.isfile(model_path), model_path
+         self.sp_model = SentencePieceProcessor(model_file=model_path)
+         logger.info(f"Reloaded SentencePiece model from {model_path}")
+
+         # BOS / EOS token IDs
+         self.n_words: int = self.sp_model.vocab_size()
+         self.bos_id: int = self.sp_model.bos_id()
+         self.eos_id: int = self.sp_model.eos_id()
+         self.pad_id: int = self.sp_model.pad_id()
+         logger.info(
+             f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+         )
+         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+     def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+         assert type(s) is str
+         t = self.sp_model.encode(s)
+         if bos:
+             t = [self.bos_id] + t
+         if eos:
+             t = t + [self.eos_id]
+         return t
+
+     def decode(self, t: List[int]) -> str:
+         return self.sp_model.decode(t)
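A hypothetical usage sketch of the wrapper above; the "tokenizer.model" path is a placeholder for the SentencePiece model file downloaded at app start, and the import assumes the repository's requirements are installed:

from llama import Tokenizer

tokenizer = Tokenizer(model_path="tokenizer.model")  # placeholder path

# encode() optionally wraps the ids with BOS/EOS; decode() maps ids back to text,
# dropping control tokens such as BOS.
ids = tokenizer.encode("Hello world", bos=True, eos=False)
assert ids[0] == tokenizer.bos_id
print(tokenizer.decode(ids))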
pages/3_📋_Types.py CHANGED
@@ -9,6 +9,7 @@ import geopandas as gpd
  import streamlit as st
  import leafmap.colormaps as cm
  from leafmap.common import hex_to_rgb
+ import leafmap.foliumap as leafmap
 
 
  st.set_page_config(layout="wide")
@@ -67,7 +68,7 @@ def app():
 
 
      with row1_col1:
-         m = leafmap.Map(center=(51.50, -0.1]), zoom=10)
+         m = leafmap.Map(center=(51.50, -0.1), zoom=10)
          m.add_geojson(borough, layer_name='London Boroughs')
          # if layers is not None:
          #     for layer in layers:
pages/4_LLM.py ADDED
@@ -0,0 +1,277 @@
+ import json
+ import os
+ import glob
+ import sys
+ import time
+ from pathlib import Path
+ from typing import Tuple
+
+ from huggingface_hub import hf_hub_download
+ from PIL import Image
+ import gradio as gr
+ import torch
+ from fairscale.nn.model_parallel.initialize import initialize_model_parallel
+
+ from llama import LLaMA, ModelArgs, Tokenizer, Transformer, VisionModel
+
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+
+ PROMPT_DICT = {
+     "prompt_input": (
+         "Below is an instruction that describes a task, paired with an input that provides further context. "
+         "Write a response that appropriately completes the request.\n\n"
+         "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+     ),
+     "prompt_no_input": (
+         "Below is an instruction that describes a task. "
+         "Write a response that appropriately completes the request.\n\n"
+         "### Instruction:\n{instruction}\n\n### Response:"
+     ),
+ }
+
+
+ def setup_model_parallel() -> Tuple[int, int]:
+     os.environ['RANK'] = '0'
+     os.environ['WORLD_SIZE'] = '1'
+     os.environ['MP'] = '1'
+     os.environ['MASTER_ADDR'] = '127.0.0.1'
+     os.environ['MASTER_PORT'] = '2223'
+     local_rank = int(os.environ.get("LOCAL_RANK", -1))
+     world_size = int(os.environ.get("WORLD_SIZE", -1))
+
+     torch.distributed.init_process_group("nccl")
+     initialize_model_parallel(world_size)
+     torch.cuda.set_device(local_rank)
+
+     # seed must be the same in all processes
+     torch.manual_seed(1)
+     return local_rank, world_size
+
+
+ def load(
+     ckpt0_path: str,
+     ckpt1_path: str,
+     param_path: str,
+     tokenizer_path: str,
+     instruct_adapter_path: str,
+     caption_adapter_path: str,
+     local_rank: int,
+     world_size: int,
+     max_seq_len: int,
+     max_batch_size: int,
+ ) -> LLaMA:
+     start_time = time.time()
+     print("Loading")
+     instruct_adapter_checkpoint = torch.load(
+         instruct_adapter_path, map_location="cpu")
+     caption_adapter_checkpoint = torch.load(
+         caption_adapter_path, map_location="cpu")
+     with open(param_path, "r") as f:
+         params = json.loads(f.read())
+
+     model_args: ModelArgs = ModelArgs(
+         max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
+     )
+     model_args.adapter_layer = int(
+         instruct_adapter_checkpoint['adapter_query.weight'].shape[0] / model_args.adapter_len)
+     model_args.cap_adapter_layer = int(
+         caption_adapter_checkpoint['cap_adapter_query.weight'].shape[0] / model_args.cap_adapter_len)
+
+     tokenizer = Tokenizer(model_path=tokenizer_path)
+     model_args.vocab_size = tokenizer.n_words
+     torch.set_default_tensor_type(torch.cuda.HalfTensor)
+     model = Transformer(model_args)
+
+     # To reduce memory usage
+     ckpt0 = torch.load(ckpt0_path, map_location='cuda')
+     model.load_state_dict(ckpt0, strict=False)
+     del ckpt0
+     torch.cuda.empty_cache()
+
+     ckpt1 = torch.load(ckpt1_path, map_location='cuda')
+     model.load_state_dict(ckpt1, strict=False)
+     del ckpt1
+     torch.cuda.empty_cache()
+
+     vision_model = VisionModel(model_args)
+
+     torch.set_default_tensor_type(torch.FloatTensor)
+     model.load_state_dict(instruct_adapter_checkpoint, strict=False)
+     model.load_state_dict(caption_adapter_checkpoint, strict=False)
+     vision_model.load_state_dict(caption_adapter_checkpoint, strict=False)
+
+     generator = LLaMA(model, tokenizer, vision_model)
+     print(f"Loaded in {time.time() - start_time:.2f} seconds")
+     return generator
+
+
+ def instruct_generate(
+     instruct: str,
+     input: str = 'none',
+     max_gen_len=512,
+     temperature: float = 0.1,
+     top_p: float = 0.75,
+ ):
+     if input == 'none':
+         prompt = PROMPT_DICT['prompt_no_input'].format_map(
+             {'instruction': instruct, 'input': ''})
+     else:
+         prompt = PROMPT_DICT['prompt_input'].format_map(
+             {'instruction': instruct, 'input': input})
+
+     results = generator.generate(
+         [prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
+     )
+     result = results[0].strip()
+     print(result)
+     return result
+
+
+ def caption_generate(
+     img: str,
+     max_gen_len=512,
+     temperature: float = 0.1,
+     top_p: float = 0.75,
+ ):
+     imgs = [Image.open(img).convert('RGB')]
+     prompts = ["Generate caption of this image :",] * len(imgs)
+
+     results = generator.generate(
+         prompts, imgs=imgs, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
+     )
+     result = results[0].strip()
+     print(result)
+     return result
+
+
+ def download_llama_adapter(instruct_adapter_path, caption_adapter_path):
+     if not os.path.exists(instruct_adapter_path):
+         os.system(
+             f"wget -q -O {instruct_adapter_path} https://github.com/ZrrSkywalker/LLaMA-Adapter/releases/download/v.1.0.0/llama_adapter_len10_layer30_release.pth")
+
+     if not os.path.exists(caption_adapter_path):
+         os.system(
+             f"wget -q -O {caption_adapter_path} https://github.com/ZrrSkywalker/LLaMA-Adapter/releases/download/v.1.0.0/llama_adapter_len10_layer30_caption_vit_l.pth")
+
+
+ # ckpt_path = "/data1/llma/7B/consolidated.00.pth"
+ # param_path = "/data1/llma/7B/params.json"
+ # tokenizer_path = "/data1/llma/tokenizer.model"
+ ckpt0_path = hf_hub_download(
+     repo_id="csuhan/llama_storage", filename="consolidated.00_part0.pth")
+ ckpt1_path = hf_hub_download(
+     repo_id="csuhan/llama_storage", filename="consolidated.00_part1.pth")
+ param_path = hf_hub_download(
+     repo_id="nyanko7/LLaMA-7B", filename="params.json")
+ tokenizer_path = hf_hub_download(
+     repo_id="nyanko7/LLaMA-7B", filename="tokenizer.model")
+ instruct_adapter_path = "llama_adapter_len10_layer30_release.pth"
+ caption_adapter_path = "llama_adapter_len10_layer30_caption_vit_l.pth"
+ max_seq_len = 512
+ max_batch_size = 1
+
+ # download models
+ # download_llama_adapter(instruct_adapter_path, caption_adapter_path)
+
+ local_rank, world_size = setup_model_parallel()
+ if local_rank > 0:
+     sys.stdout = open(os.devnull, "w")
+
+ generator = load(
+     ckpt0_path, ckpt1_path, param_path, tokenizer_path, instruct_adapter_path, caption_adapter_path, local_rank, world_size, max_seq_len, max_batch_size
+ )
+
+
+ def create_instruct_demo():
+     with gr.Blocks() as instruct_demo:
+         with gr.Row():
+             with gr.Column():
+                 instruction = gr.Textbox(lines=2, label="Instruction")
+                 input = gr.Textbox(
+                     lines=2, label="Context input", placeholder='none')
+                 max_len = gr.Slider(minimum=1, maximum=512,
+                                     value=128, label="Max length")
+                 with gr.Accordion(label='Advanced options', open=False):
+                     temp = gr.Slider(minimum=0, maximum=1,
+                                      value=0.1, label="Temperature")
+                     top_p = gr.Slider(minimum=0, maximum=1,
+                                       value=0.75, label="Top p")
+
+                 run_botton = gr.Button("Run")
+
+             with gr.Column():
+                 outputs = gr.Textbox(lines=10, label="Output")
+
+         inputs = [instruction, input, max_len, temp, top_p]
+
+         examples = [
+             "Tell me about alpacas.",
+             "Write a Python program that prints the first 10 Fibonacci numbers.",
+             "Write a conversation between the sun and pluto.",
+             "Write a theory to explain why cat never existed",
+         ]
+         examples = [
+             [x, "none", 128, 0.1, 0.75]
+             for x in examples]
+
+         gr.Examples(
+             examples=examples,
+             inputs=inputs,
+             outputs=outputs,
+             fn=instruct_generate,
+             cache_examples=os.getenv('SYSTEM') == 'spaces'
+         )
+         run_botton.click(fn=instruct_generate, inputs=inputs, outputs=outputs)
+     return instruct_demo
+
+
+ def create_caption_demo():
+     with gr.Blocks() as instruct_demo:
+         with gr.Row():
+             with gr.Column():
+                 img = gr.Image(label='Input', type='filepath')
+                 max_len = gr.Slider(minimum=1, maximum=512,
+                                     value=64, label="Max length")
+                 with gr.Accordion(label='Advanced options', open=False):
+                     temp = gr.Slider(minimum=0, maximum=1,
+                                      value=0.1, label="Temperature")
+                     top_p = gr.Slider(minimum=0, maximum=1,
+                                       value=0.75, label="Top p")
+
+                 run_botton = gr.Button("Run")
+
+             with gr.Column():
+                 outputs = gr.Textbox(lines=10, label="Output")
+
+         inputs = [img, max_len, temp, top_p]
+
+         examples = glob.glob("caption_demo/*.jpg")
+         examples = [
+             [x, 64, 0.1, 0.75]
+             for x in examples]
+
+         gr.Examples(
+             examples=examples,
+             inputs=inputs,
+             outputs=outputs,
+             fn=caption_generate,
+             cache_examples=os.getenv('SYSTEM') == 'spaces'
+         )
+         run_botton.click(fn=caption_generate, inputs=inputs, outputs=outputs)
+     return instruct_demo
+
+
+ description = """
+ # LLaMA-Adapter🚀
+ The official demo for **LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention**.
+ Please refer to our [arXiv paper](https://arxiv.org/abs/2303.16199) and [github](https://github.com/ZrrSkywalker/LLaMA-Adapter) for more details.
+ """
+
+ with gr.Blocks(css='style.css') as demo:
+     gr.Markdown(description)
+     with gr.TabItem("Instruction-Following"):
+         create_instruct_demo()
+     with gr.TabItem("Image Captioning"):
+         create_caption_demo()
+
+ demo.queue(api_open=True, concurrency_count=1).launch()
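For reference, a self-contained sketch of how the Alpaca-style prompt template used by instruct_generate expands when no context input is given (the template string is duplicated here so the snippet runs without the model weights):

# String handling only; no checkpoints or GPU required.
PROMPT_NO_INPUT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:"
)
print(PROMPT_NO_INPUT.format_map({'instruction': "Tell me about alpacas."}))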
requirements.txt CHANGED
@@ -1,4 +1,5 @@
  --find-links=https://girder.github.io/large_image_wheels GDAL
+ --extra-index-url https://download.pytorch.org/whl/cu113
  # cartopy
  folium==0.13.0
  # ipywidgets<8.0.5
@@ -22,4 +23,10 @@ streamlit-extras
  hugchat
  # git+https://github.com/giswqs/leafmap
  # git+https://github.com/giswqs/geemap
-
+ torch==1.12.0+cu113
+ fairscale
+ sentencepiece
+ Pillow
+ huggingface_hub
+ git+https://github.com/csuhan/timm_0_3_2.git
+ git+https://github.com/openai/CLIP.git