root committed on
Commit
67922e4
1 Parent(s): 649bc8a

update sr tp modeling

Files changed (1)
  1. sr_tp_modeling.py +890 -0
sr_tp_modeling.py ADDED
@@ -0,0 +1,890 @@
+ """ PyTorch SRV1 model."""
+ import sys
+ import os
+ from os import path
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+ print(sys.path)
+ import math
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ from transformers.activations import ACT2FN
+ from transformers import AutoTokenizer, AutoConfig
+ from .configuration_srv1 import SRV1Config
+
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPast,
+     CausalLMOutputWithPast,
+ )
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import (
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     logging,
+     replace_return_docstrings,
+ )
+
+ from .layers import (
+     TensorParallelColumnLinear,
+     TensorParallelEmbedding,
+     TensorParallelHead,
+     TensorParallelRowLinear,
+     load_layer_norm_no_bias,
+ )
+ from .dist import initialize_torch_distributed
+ from .weights import Weights
+
+ logger = logging.get_logger(__name__)
+
+ _CONFIG_FOR_DOC = "SRV1Config"
+
+
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
+ def _make_causal_mask(
+     input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+ ):
+     """
+     Make causal mask used for bi-directional self-attention.
+     """
+     bsz, tgt_len = input_ids_shape
+     mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+     mask_cond = torch.arange(mask.size(-1), device=device)
+     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+     mask = mask.to(dtype)
+
+     if past_key_values_length > 0:
+         mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
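+ # Worked example of the mask above (an illustration only, not part of the model logic):
+ # for input_ids_shape = torch.Size([1, 3]) and past_key_values_length = 1 the call
+ #     _make_causal_mask(torch.Size([1, 3]), torch.float32, torch.device("cpu"), past_key_values_length=1)
+ # returns a [1, 1, 3, 4] tensor where row i holds 0.0 for the cached position and for
+ # every current key position <= i, and torch.finfo(torch.float32).min elsewhere, so
+ # adding it to the attention scores blocks future tokens before the softmax.
+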
+
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+     """
+     Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+     """
+     bsz, src_len = mask.size()
+     tgt_len = tgt_len if tgt_len is not None else src_len
+
+     expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+     inverted_mask = 1.0 - expanded_mask
+
+     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
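+ # Illustration of the expansion above (values chosen only for the example): a padding
+ # mask [[1, 1, 0]] becomes a [1, 1, tgt_len, 3] tensor whose last key column holds
+ # torch.finfo(dtype).min and whose other columns hold 0.0, i.e. padded tokens are
+ # excluded from attention for every query position.
+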
+
+ class SRV1RMSNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6):
+         """
+         SRV1RMSNorm is equivalent to T5LayerNorm
+         """
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         input_dtype = hidden_states.dtype
+         hidden_states = hidden_states.to(torch.float32)
+         variance = hidden_states.pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+         return self.weight * hidden_states.to(input_dtype)
+
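+ # The forward above computes, per hidden vector x of size d (in float32 for stability):
+ #     y = w * x / sqrt(mean(x_i^2) + eps)
+ # i.e. RMS normalization with a learned per-channel scale w and no mean subtraction or bias.
+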
+
+ SRV1RMSNorm.load_no_bias = load_layer_norm_no_bias
+
+
+ class SRV1RotaryEmbedding(torch.nn.Module):
+     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+         super().__init__()
+
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.base = base
+         self.inv_freq = self._create_inv_freq(dim=dim, base=base, device=device)
+
+         # Build here to make `torch.jit.trace` work.
+         self._set_cos_sin_cache(
+             seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+         )
+
+     def _set_cos_sin_cache(self, seq_len, device, dtype):
+         self.max_seq_len_cached = seq_len
+         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+         # Different from paper, but it uses a different permutation in order to obtain the same calculation
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+         self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
+     def forward(self, x, seq_len=None):
+         # x: [bs, num_attention_heads, seq_len, head_size]
+         if seq_len > self.max_seq_len_cached:
+             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+         return (
+             self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+         )
+
+     def _create_inv_freq(self, dim, base, device):
+         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
+         return inv_freq
+
+
+ class SRV1LinearScalingRotaryEmbedding(SRV1RotaryEmbedding):
+     """Rotary embedding with linear position scaling (`rope_scaling={"type": "linear", ...}`)."""
+
+     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+         self.scaling_factor = scaling_factor
+         super().__init__(dim, max_position_embeddings, base, device)
+
+     def _set_cos_sin_cache(self, seq_len, device, dtype):
+         self.max_seq_len_cached = seq_len
+         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+         t = t / self.scaling_factor
+
+         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+         # Different from paper, but it uses a different permutation in order to obtain the same calculation
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+         self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
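+ # Example of the effect of linear scaling (numbers are illustrative only): with
+ # scaling_factor=2.0 every position index t is mapped to t/2 before the cos/sin cache
+ # is built, so a model trained with max_position_embeddings=2048 can be run on
+ # sequences of roughly 4096 positions while reusing the same rotary frequencies.
+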
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+     cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+     sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ class SRV1MLP(nn.Module):
+     def __init__(self, prefix, config: SRV1Config, weights):
+         super().__init__()
+         self.gate_proj = TensorParallelColumnLinear.load(
+             config=config, prefix=f"{prefix}.gate_proj", weights=weights, bias=False
+         )
+         self.up_proj = TensorParallelColumnLinear.load(
+             config=config, prefix=f"{prefix}.up_proj", weights=weights, bias=False
+         )
+         self.down_proj = TensorParallelRowLinear.load(
+             config=config, prefix=f"{prefix}.down_proj", weights=weights, bias=False
+         )
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, x):
+         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+         return down_proj
+
+
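+ # The MLP above is the gated (SwiGLU-style) feed-forward used by Llama-family models:
+ # down_proj(act(gate_proj(x)) * up_proj(x)). gate_proj/up_proj are column-split and
+ # down_proj is row-split across tensor-parallel ranks, so each rank holds
+ # intermediate_size / world_size of the hidden dimension; the row-parallel down_proj
+ # is expected to all-reduce the partial results (the usual Megatron-style split,
+ # assuming the local TensorParallelRowLinear follows that convention).
+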
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
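+ # Grouped-query attention example (head counts are illustrative): with
+ # num_attention_heads=32 and num_key_value_heads=8, n_rep is 4 and a key/value tensor
+ # of shape (batch, 8, seqlen, head_dim) is expanded to (batch, 32, seqlen, head_dim),
+ # so every group of 4 query heads shares one key/value head.
+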
+ class SRV1Attention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     def __init__(self, prefix, config: SRV1Config, weights):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.num_key_value_heads = config.num_key_value_heads
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+         self.max_position_embeddings = config.max_position_embeddings
+         self.rope_theta = getattr(config, "rope_theta", 10000)
+
+         if (self.head_dim * self.num_heads) != self.hidden_size:
+             raise ValueError(
+                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                 f" and `num_heads`: {self.num_heads})."
+             )
+
+         # for 1d tensor model parallel
+         process_group = weights.process_group
+         self.hidden_size = self.hidden_size // process_group.size()
+         self.num_heads = self.num_heads // process_group.size()
+         self.num_key_value_heads = self.num_key_value_heads // process_group.size()
+
+         self.q_proj = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.q_proj", weights=weights, bias=False)
+         self.k_proj = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.k_proj", weights=weights, bias=False)
+         self.v_proj = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.v_proj", weights=weights, bias=False)
+         self.o_proj = TensorParallelRowLinear.load(config, prefix=f"{prefix}.o_proj", weights=weights, bias=False)
+         if self.config.rope_scaling is not None and self.config.rope_scaling['type'] == "linear":
+             # Note: do not use weights.device here; the rotary cache is built on CPU,
+             # so the caller has to move the model to the current rank afterwards (model.to(rank)).
+             self.rotary_emb = SRV1LinearScalingRotaryEmbedding(
+                 self.head_dim, self.max_position_embeddings, base=self.rope_theta, scaling_factor=self.config.rope_scaling['factor']
+             )
+         else:
+             self.rotary_emb = SRV1RotaryEmbedding(
+                 self.head_dim, self.max_position_embeddings, base=self.rope_theta
+             )
+
+     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         bsz, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+         kv_seq_len = key_states.shape[-2]
+         if past_key_value is not None:
+             kv_seq_len += past_key_value[0].shape[-2]
+         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+         if past_key_value is not None:
+             # reuse k, v, self_attention
+             key_states = torch.cat([past_key_value[0], key_states], dim=2)
+             value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+         past_key_value = (key_states, value_states) if use_cache else None
+
+         # repeat k/v heads if n_kv_heads < n_heads
+         key_states = repeat_kv(key_states, self.num_key_value_groups)
+         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+             raise ValueError(
+                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                 f" {attn_weights.size()}"
+             )
+
+         if attention_mask is not None:
+             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                 raise ValueError(
+                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                 )
+             attn_weights = attn_weights + attention_mask
+
+         # upcast attention to fp32
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+         attn_output = torch.matmul(attn_weights, value_states)
+
+         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+             raise ValueError(
+                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                 f" {attn_output.size()}"
+             )
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+         attn_output = self.o_proj(attn_output)
+
+         if not output_attentions:
+             attn_weights = None
+
+         return attn_output, attn_weights, past_key_value
+
+
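+ # Tensor-parallel sizing example (world size chosen only for illustration): with
+ # hidden_size=4096, num_attention_heads=32, num_key_value_heads=8 and a process group
+ # of size 4, each rank ends up with hidden_size=1024, num_heads=8 and
+ # num_key_value_heads=2, so the column-parallel q/k/v projections and the row-parallel
+ # o_proj only ever see that rank's shard of the heads.
+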
+ class SRV1DecoderLayer(nn.Module):
+     def __init__(self, prefix, config: SRV1Config, weights):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.self_attn = SRV1Attention(prefix=f"{prefix}.self_attn", config=config, weights=weights)
+         self.mlp = SRV1MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+         self.input_layernorm = SRV1RMSNorm.load_no_bias(
+             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+         )
+         self.post_attention_layernorm = SRV1RMSNorm.load_no_bias(
+             prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             use_cache (`bool`, *optional*):
+                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                 (see `past_key_values`).
+             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+         """
+
+         residual = hidden_states
+
+         hidden_states = self.input_layernorm(hidden_states)
+
+         # Self Attention
+         hidden_states, self_attn_weights, present_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (self_attn_weights,)
+
+         if use_cache:
+             outputs += (present_key_value,)
+
+         return outputs
+
+
+ SRV1_START_DOCSTRING = r"""
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+     etc.)
+
+     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+     and behavior.
+
+     Parameters:
+         config ([`SRV1Config`]):
+             Model configuration class with all the parameters of the model. Initializing with a config file does not
+             load the weights associated with the model, only the configuration. Check out the
+             [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+
+ @add_start_docstrings(
+     "The bare SRV1 Model outputting raw hidden-states without any specific head on top.",
+     SRV1_START_DOCSTRING,
+ )
+ class SRV1PreTrainedModel(PreTrainedModel):
+     config_class = SRV1Config
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["SRV1DecoderLayer"]
+     _skip_keys_device_placement = "past_key_values"
+
+     def _init_weights(self, module):
+         std = self.config.initializer_range
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if isinstance(module, SRV1Model):
+             module.gradient_checkpointing = value
+
+
+ SRV1_INPUTS_DOCSTRING = r"""
+     Args:
+         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+             it.
+
+             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+             [`PreTrainedTokenizer.__call__`] for details.
+
+             [What are input IDs?](../glossary#input-ids)
+         attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+
+             [What are attention masks?](../glossary#attention-mask)
+
+             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+             [`PreTrainedTokenizer.__call__`] for details.
+
+             If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+             `past_key_values`).
+
+             If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+             and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+             information on the default strategy.
+
+             - 1 indicates the head is **not masked**,
+             - 0 indicates the head is **masked**.
+         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+             config.n_positions - 1]`.
+
+             [What are position IDs?](../glossary#position-ids)
+         past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+             Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+             `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+             `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+             Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+             blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+             If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+             `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+             model's internal embedding lookup matrix.
+         use_cache (`bool`, *optional*):
+             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+             `past_key_values`).
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+
+ @add_start_docstrings(
+     "The bare SRV1 Model outputting raw hidden-states without any specific head on top.",
+     SRV1_START_DOCSTRING,
+ )
+ class SRV1Model(SRV1PreTrainedModel):
+     """
+     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SRV1DecoderLayer`]
+
+     Args:
+         config: SRV1Config
+     """
+
+     def __init__(self, config: SRV1Config, weights):
+         super().__init__(config)
+         self.embed_tokens = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights)
+         self.layers = nn.ModuleList(
+             [
+                 SRV1DecoderLayer(prefix=f"model.layers.{layer_id}", config=config, weights=weights)
+                 for layer_id in range(config.num_hidden_layers)
+             ]
+         )
+         # self.norm = SRV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.norm = SRV1RMSNorm.load_no_bias(prefix="model.norm", weights=weights, eps=config.rms_norm_eps)
+         self.gradient_checkpointing = False
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+
+     # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+     def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+         # create causal mask
+         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+         combined_attention_mask = None
+         if input_shape[-1] > 1:
+             combined_attention_mask = _make_causal_mask(
+                 input_shape,
+                 inputs_embeds.dtype,
+                 device=inputs_embeds.device,
+                 past_key_values_length=past_key_values_length,
+             )
+
+         if attention_mask is not None:
+             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+             expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                 inputs_embeds.device
+             )
+             combined_attention_mask = (
+                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+             )
+
+         return combined_attention_mask
+
+     @add_start_docstrings_to_model_forward(SRV1_INPUTS_DOCSTRING)
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # retrieve input_ids and inputs_embeds
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+         elif input_ids is not None:
+             batch_size, seq_length = input_ids.shape
+         elif inputs_embeds is not None:
+             batch_size, seq_length, _ = inputs_embeds.shape
+         else:
+             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+         seq_length_with_past = seq_length
+         past_key_values_length = 0
+
+         if past_key_values is not None:
+             past_key_values_length = past_key_values[0][0].shape[2]
+             seq_length_with_past = seq_length_with_past + past_key_values_length
+
+         if position_ids is None:
+             device = input_ids.device if input_ids is not None else inputs_embeds.device
+             position_ids = torch.arange(
+                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+             )
+             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+         else:
+             position_ids = position_ids.view(-1, seq_length).long()
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+         # embed positions
+         if attention_mask is None:
+             attention_mask = torch.ones(
+                 (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+             )
+         attention_mask = self._prepare_decoder_attention_mask(
+             attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+         )
+
+         hidden_states = inputs_embeds
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 logger.warning_once(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                 )
+                 use_cache = False
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = () if use_cache else None
+
+         for idx, decoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+             if self.gradient_checkpointing and self.training:
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         # past_key_value and output_attentions are captured from the enclosing scope
+                         return module(*inputs, past_key_value, output_attentions)
+
+                     return custom_forward
+
+                 layer_outputs = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(decoder_layer),
+                     hidden_states,
+                     attention_mask,
+                     position_ids,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=attention_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_value,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = next_decoder_cache if use_cache else None
+         if not return_dict:
+             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+
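+ # Cache layout produced by the forward above: `past_key_values` is a tuple with one
+ # entry per decoder layer, each entry holding the (key, value) tensors of shape
+ # (batch, num_key_value_heads_per_rank, seq_len_so_far, head_dim); on the next step
+ # only the new token's states are computed and concatenated onto these tensors.
+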
+ class SRV1ForCausalLM(SRV1PreTrainedModel):
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config, weights):
+         super().__init__(config)
+         self.model = SRV1Model(config, weights)
+         self.lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights)
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     @add_start_docstrings_to_model_forward(SRV1_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         r"""
+         Args:
+             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+         Returns:
+
+         Example:
+
+         ```python
+         >>> from transformers import AutoTokenizer, SRV1ForCausalLM
+
+         >>> model = SRV1ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+         >>> prompt = "Hey, are you conscious? Can you talk to me?"
+         >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+         >>> # Generate
+         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+         ```"""
+
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)
+         logits = logits.float()
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Enable model parallelism
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+     ):
+         if past_key_values:
+             input_ids = input_ids[:, -1:]
+
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -1].unsqueeze(-1)
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+             }
+         )
+         return model_inputs
+
+     @staticmethod
+     def _reorder_cache(past_key_values, beam_idx):
+         reordered_past = ()
+         for layer_past in past_key_values:
+             reordered_past += (
+                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+             )
+         return reordered_past
+
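+ # Beam-search bookkeeping example (beam indices are illustrative): if beam_idx is
+ # tensor([2, 0]) then for every layer the cached key and value tensors are re-indexed
+ # along the batch dimension so that beam 0 continues from what was beam 2 and beam 1
+ # continues from what was beam 0.
+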
+ class SRV1ForCausalLMParallel(SRV1ForCausalLM):
+     # def __init__(self, model_id:str, revision: Optional[str] = None,
+     #              quantize: Optional[str] = None,
+     #              dtype: Optional[torch.dtype] = None,
+     #              trust_remote_code: bool = False):
+     def __init__(self, config, **kwargs):
+         model_id = kwargs.get("pretrained_model_name_or_path", None)
+         revision = kwargs.get("revision", None)
+         trust_remote_code = kwargs.get("trust_remote_code", False)
+         quantize = kwargs.get("quantize", None)
+         dtype = kwargs.get("dtype", None)
+         print("Start initializing...")
+         self.process_group, rank, world_size = initialize_torch_distributed()
+         print(f"RANK[{rank}]: Distributed Initialize Success")
+         if torch.cuda.is_available():
+             device = torch.device(f"cuda:{rank}")
+             dtype = torch.float16 if dtype is None else dtype
+             print(f"Use dtype {dtype}")
+         else:
+             raise NotImplementedError("Flash is only available on GPU")
+
+         print(f"Will read model dir {model_id}")
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             model_id,
+             revision=revision,
+             padding_side="left",
+             truncation_side="left",
+             trust_remote_code=trust_remote_code,
+         )
+         # config already defined in from_pretrained
+         # config = SRV1Config.from_pretrained(model_id, revision=revision, trust_remote_code=trust_remote_code)
+         config.quantize = quantize
+         torch.distributed.barrier(group=self.process_group)
+         import glob
+         filenames = glob.glob(f"{model_id}/*.safetensors")
+         print(f"Will read filename {filenames}")
+         weights = Weights(filenames=filenames, device=device, dtype=dtype, process_group=self.process_group)
+         print(f"RANK[{rank}]: Loaded Weights success. device:{device}")
+
+         torch.distributed.barrier(group=self.process_group)
+         super(SRV1ForCausalLMParallel, self).__init__(
+             config=config,
+             weights=weights
+         )
+         print(f"RANK[{rank}]: parallel load success")
+         torch.distributed.barrier(group=self.process_group)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, config=None, **kwargs):
+         config_path = config if config is not None else pretrained_model_name_or_path
+
+         config = cls.config_class.from_pretrained(
+             config_path,
+             **kwargs,
+         )
+         kwargs.update({"pretrained_model_name_or_path": pretrained_model_name_or_path})
+         model = cls(config, *model_args, **kwargs)
+         return model
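+
+ # Usage sketch (assumptions: the checkpoint directory holds local *.safetensors shards,
+ # and `initialize_torch_distributed` picks up the usual RANK/WORLD_SIZE/MASTER_ADDR
+ # environment variables, e.g. as set by `torchrun --nproc_per_node=<world_size>`;
+ # the path below is a placeholder):
+ #
+ #     model = SRV1ForCausalLMParallel.from_pretrained("/path/to/srv1_checkpoint", dtype=torch.float16)
+ #     model.to(torch.cuda.current_device())  # rotary caches are built on CPU, see SRV1Attention
+ #     inputs = model.tokenizer("hello", return_tensors="pt").to(model.device)
+ #     output_ids = model.generate(**inputs, max_new_tokens=32)
+ #
+ # Every rank runs the same script; each one loads only its shard of the weights and
+ # the tensor-parallel layers communicate over `model.process_group`.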