cwoolee committed
Commit 279b61f
1 Parent(s): 9e09867

Upload model

Files changed (2)
  1. config.json +4 -0
  2. modeling_blast.py +244 -0
config.json CHANGED
@@ -4,6 +4,10 @@
   ],
   "attention_bias": false,
   "attention_dropout": 0.0,
+  "auto_map": {
+    "AutoConfig": "modeling_blast.BlastLlamaConfig",
+    "AutoModelForCausalLM": "modeling_blast.BlastModelForCausalLM"
+  },
   "blast_num_blocks": [
     16
   ],
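The added auto_map entries point the Transformers Auto classes at the custom code in modeling_blast.py, so the checkpoint can be loaded directly from the Hub once remote code execution is allowed. A minimal loading sketch (the repository id is a placeholder, not taken from this commit):

    from transformers import AutoConfig, AutoModelForCausalLM

    repo_id = "<namespace>/<model-name>"  # placeholder Hub id for this repository
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)           # -> BlastLlamaConfig
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)  # -> BlastModelForCausalLM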
modeling_blast.py ADDED
@@ -0,0 +1,244 @@
+import os
+import logging
+
+import torch
+import torch.nn as nn
+
+from transformers import PretrainedConfig, LlamaConfig, LlamaModel, LlamaForCausalLM
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRotaryEmbedding, LlamaRMSNorm
+from typing import List, Union, Tuple
+
+from huggingface_hub import PyTorchModelHubMixin
+
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+class BlastLlamaConfig(LlamaConfig):
+    model_type = "blast_llama"
+    keys_to_ignore_at_inference = ["blast_decomposed_weight_path"]
+    def __init__(
+        self,
+        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
+        blast_rank={'q_proj': 1024, 'k_proj': 1024, 'v_proj': 1024, 'o_proj': 1024, 'gate_proj': 1488, 'up_proj': 1488, 'down_proj': 1488},
+        blast_num_blocks: Union[Union[List, Tuple], int] = 4,
+        indices=[i for i in range(32)],
+        precompute_matrix=False,
+        **kwargs,
+    ):
+        self.target_modules = target_modules
+        self.blast_rank = blast_rank
+        self.blast_num_blocks = blast_num_blocks,  # trailing comma stores a one-element tuple; unwrapped in replace_layers_with_blast
+        self.indices = indices
+        self.precompute_matrix = precompute_matrix
+        #self.blast_decomposed_weight_path = blast_decomposed_weight_path
+        super().__init__(**kwargs)
+
+
+def get_parent(model, mn):
+    parent_name = ".".join(mn.split(".")[:-1])
+    for n, m in model.named_modules():
+        if n == parent_name:
+            return m
+
+
+def replace_layers_with_blast(
+    model,
+    target_modules,
+    blast_rank,
+    blast_num_blocks,
+    indices,
+    precompute_matrix=False,
+):
+    for mn, m in model.named_modules():
+        if isinstance(m, torch.nn.Linear):
+            for tmn in target_modules:
+                if tmn in mn:
+                    layer_idx = int(mn.split(".")[-3])
+                    if layer_idx not in indices:
+                        continue
+                    if isinstance(blast_rank, dict):
+                        for k in blast_rank.keys():
+                            if k in mn:
+                                rank = blast_rank[k]
+                                break
+                    elif isinstance(blast_rank, int):
+                        rank = blast_rank
+                    elif isinstance(blast_rank, float):
+                        rank = int(blast_rank * min(m.weight.shape[0], m.weight.shape[1]))
+                    else:
+                        raise ValueError(f"blast_rank must have either dict, int, or float type, got: {type(blast_rank)}.")
+
+                    if isinstance(blast_num_blocks, dict):
+                        for k in blast_num_blocks.keys():
+                            if k in mn:
+                                num_blocks = blast_num_blocks[k]
+                                break
+                    elif isinstance(blast_num_blocks, int):
+                        num_blocks = blast_num_blocks
+                    elif isinstance(blast_num_blocks, tuple):
+                        num_blocks = blast_num_blocks
+                        if len(blast_num_blocks) == 1:
+                            num_blocks = num_blocks[0]
+                            if isinstance(num_blocks, list):
+                                num_blocks = num_blocks[0]
+                    else:
+                        raise ValueError(f"blast_num_blocks must have either dict, int, or tuple of ints, got: {type(blast_num_blocks)}.")
+
+                    # Load Decomposed BLAST Weights
+                    new_layer = BlastLinear(
+                        in_features=m.weight.shape[1],
+                        out_features=m.weight.shape[0],
+                        num_blocks=num_blocks,
+                        rank=rank,
+                        bias=m.bias is not None,
+                        device=m.weight.device,
+                        dtype=m.weight.dtype,
+                        precompute_matrix=precompute_matrix,
+                    )
+
+                    parent_module = get_parent(model, mn)
+                    child_name = mn.split(".")[-1]
+                    parent_module.add_module(child_name, new_layer)
+
+    return model
+
+
+
+class BlastLinear(torch.nn.Module):
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 num_blocks: Union[int, Union[List, Tuple]],
+                 rank=None,
+                 bias: bool = True,
+                 device=None,
+                 dtype=torch.float32,
+                 precompute_matrix=False,
+                 ) -> None:
+
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        if isinstance(num_blocks, int):
+            num_blocks = (num_blocks, num_blocks)
+        if isinstance(num_blocks[0], list):
+            num_blocks[0] = num_blocks[0][0]
+        if isinstance(num_blocks[1], list):
+            num_blocks[1] = num_blocks[1][0]
+        assert len(num_blocks) == 2
+        assert in_features % num_blocks[1] == 0 and out_features % num_blocks[0] == 0
+        self.num_blocks = num_blocks
+        self.precompute_matrix = precompute_matrix
+
+        if rank is None:
+            rank = min(in_features, out_features)
+        if isinstance(rank, float):
+            rank = int(rank * min(in_features, out_features))
+
+        self.rank = rank
+
+
+        self.B = nn.Parameter(torch.empty(num_blocks[0], out_features // num_blocks[0], rank, device=device, dtype=dtype))
+        self.C = nn.Parameter(torch.empty(num_blocks[1], rank, in_features // num_blocks[1], device=device, dtype=dtype))
+        self.D = nn.Parameter(torch.empty(num_blocks[0], num_blocks[1], rank, device=device, dtype=dtype))
+
+
+        if bias:
+            self.bias = nn.Parameter(torch.empty(out_features, device=device, dtype=dtype))
+        else:
+            self.register_parameter('bias', None)
+        self.rank_score = 0.
+
+    def get_matrix(self):
+        C = self.C.unsqueeze(0)   # (1, b2, r, q)
+        D = self.D.unsqueeze(-1)  # (b1, b2, r, 1)
+        DC = C * D
+        DC = DC.permute(0, 1, 3, 2).reshape(self.num_blocks[0], self.in_features, self.rank)  # (b1, n, r)
+        B = self.B  # (b1, p, r)
+        A = torch.bmm(B, DC.transpose(1, 2))
+        A = A.view(self.out_features, self.in_features)
+        return A
+
+    #@torch.compile
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+        if self.precompute_matrix:
+            if self.training:
+                self.A = None  # drop any cached matrix while training
+                A = self.get_matrix()
+            else:
+                if not hasattr(self, 'A') or self.A is None:
+                    self.A = self.get_matrix()  # cache the dense weight for inference
+                A = self.A
+            out = torch.nn.functional.linear(x, A)
+
+        else:
+
+            x_shape = x.shape
+            x = x.flatten(0, -2)
+
+            x = x.view(-1, self.num_blocks[1], x.shape[-1] // self.num_blocks[1]).transpose(0, 1)  # (b2, n, in_features // b2)
+            y = torch.bmm(x, self.C.transpose(1, 2))  # (b2, n, rank)
+
+            z = y.unsqueeze(0) * self.D.unsqueeze(2)  # (b1, b2, n, rank)
+            z = z.sum(1)  # (b1, n, rank)
+
+            out = torch.bmm(z, self.B.transpose(1, 2))  # (b1, n, out_features // b1)
+            out = out.transpose(0, 1).reshape(*(x_shape[:-1] + (self.out_features,)))
+
+
+        if self.bias is not None:
+            out += self.bias.to(x.dtype)
+        return out
+
+    def extra_repr(self) -> str:
+        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}, rank={self.rank}, num_blocks={self.num_blocks}'
+
+
+class BlastLlamaModel(LlamaModel):
+    config_class = BlastLlamaConfig
+
+    def __init__(self, config: BlastLlamaConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList(
+            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+
+        logger.info("Replacing Linear layers with BlastLinear...")
+        replace_layers_with_blast(
+            self.layers,
+            config.target_modules,
+            config.blast_rank,
+            config.blast_num_blocks,
+            config.indices,
+            config.precompute_matrix,
+            #config.blast_decomposed_weight_path,
+        )
+        #config.blast_decomposed_weight_path = None
+
+        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = LlamaRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+class BlastModelForCausalLM(LlamaForCausalLM, PyTorchModelHubMixin):
+    config_class = BlastLlamaConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = BlastLlamaModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+
+
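In BlastLinear, block (i, j) of the dense weight equals B[i] @ diag(D[i, j]) @ C[j]; get_matrix() materializes that matrix, while forward() applies the factors block-wise without forming it. A minimal consistency sketch, not part of this commit: it assumes modeling_blast.py is importable from the working directory and uses random parameters, since B, C and D are left uninitialized by __init__:

    import torch
    from modeling_blast import BlastLinear  # assumes modeling_blast.py is on the import path

    torch.manual_seed(0)
    layer = BlastLinear(in_features=64, out_features=48, num_blocks=4, rank=8, bias=False)

    # B, C and D are allocated with torch.empty, so fill them before use.
    for p in (layer.B, layer.C, layer.D):
        torch.nn.init.normal_(p, std=0.02)

    x = torch.randn(2, 5, 64)
    y_factored = layer(x)                                        # block-wise factored forward
    y_dense = torch.nn.functional.linear(x, layer.get_matrix())  # materialized dense weight
    print(torch.allclose(y_factored, y_dense, atol=1e-5))        # expected: True

With precompute_matrix=True the layer instead caches get_matrix() at inference time and falls back to a plain dense matmul.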