Andrei Panferov commited on
Commit
5edaefc
·
1 Parent(s): 7e4a8ff
Files changed (3) hide show
  1. configuration_llama.py +1 -1
  2. inference.py +282 -8
  3. modeling_llama.py +348 -109
configuration_llama.py CHANGED
@@ -3,7 +3,7 @@ from transformers import LlamaConfig as OrigLlamaConfig
3
 
4
  class LlamaConfig(OrigLlamaConfig):
5
  model_type = "llama_aqlm"
6
-
7
  def __init__(
8
  self,
9
  nbits_per_codebook: int = 16,
 
3
 
4
  class LlamaConfig(OrigLlamaConfig):
5
  model_type = "llama_aqlm"
6
+
7
  def __init__(
8
  self,
9
  nbits_per_codebook: int = 16,
inference.py CHANGED
@@ -1,13 +1,13 @@
1
  """ Core mathematics for Additive Quantization (AQ): initialization, reconstruction and beam search"""
2
- import random
3
- from typing import List, Optional, Tuple, Union
 
4
 
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
-
9
- from src.inference_kernels.router import forward_pass_quantized_linear
10
- from src.utils import _dequantize_weight, ellipsis, get_int_dtype, unpack_int_data
11
 
12
 
13
  class FinalizedQuantizedLinear(nn.Module):
@@ -39,12 +39,17 @@ class FinalizedQuantizedLinear(nn.Module):
39
 
40
  # CODES & CODEBOOKS
41
  self.codebooks = nn.Parameter(
42
- torch.empty((num_codebooks, self.codebook_size, out_group_size, in_group_size), **factory_kwargs),
 
 
 
43
  requires_grad=True,
44
  ) # [num_codebooks, codebook_size, out_group_size, in_group_size]
45
  self.codes = nn.Parameter(
46
  torch.empty(
47
- (num_out_groups, num_in_groups, num_codebooks), device=device, dtype=get_int_dtype(nbits_per_codebook)
 
 
48
  ),
49
  requires_grad=False,
50
  ) # [num_out_groups, num_in_groups, num_codebooks]
@@ -61,4 +66,273 @@ class FinalizedQuantizedLinear(nn.Module):
61
  self.register_parameter("bias", None)
62
 
63
  def forward(self, input: torch.Tensor) -> torch.Tensor:
64
- return forward_pass_quantized_linear(input, self.codes, self.codebooks, self.scales, self.bias)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """ Core mathematics for Additive Quantization (AQ): initialization, reconstruction and beam search"""
2
+ import functools
3
+ import os
4
+ from typing import Optional
5
 
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
+ import triton
10
+ import triton.language as tl
 
11
 
12
 
13
  class FinalizedQuantizedLinear(nn.Module):
 
39
 
40
  # CODES & CODEBOOKS
41
  self.codebooks = nn.Parameter(
42
+ torch.empty(
43
+ (num_codebooks, self.codebook_size, out_group_size, in_group_size),
44
+ **factory_kwargs,
45
+ ),
46
  requires_grad=True,
47
  ) # [num_codebooks, codebook_size, out_group_size, in_group_size]
48
  self.codes = nn.Parameter(
49
  torch.empty(
50
+ (num_out_groups, num_in_groups, num_codebooks),
51
+ device=device,
52
+ dtype=get_int_dtype(nbits_per_codebook),
53
  ),
54
  requires_grad=False,
55
  ) # [num_out_groups, num_in_groups, num_codebooks]
 
66
  self.register_parameter("bias", None)
67
 
68
  def forward(self, input: torch.Tensor) -> torch.Tensor:
69
+ return forward_pass_quantized_linear(
70
+ input, self.codes, self.codebooks, self.scales, self.bias
71
+ )
72
+
73
+
74
+ def get_int_dtype(nbits: int) -> torch.dtype:
75
+ if nbits <= 8:
76
+ return torch.int8
77
+ if nbits <= 16:
78
+ return torch.int16
79
+ if nbits <= 32:
80
+ return torch.int32
81
+ if nbits <= 64:
82
+ return torch.int64
83
+ raise ValueError(f"No dtype available for {nbits}-bit codebooks")
84
+
85
+
86
+ @torch.inference_mode()
87
+ def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
88
+ return data.to(torch.int64) % (2**nbits)
89
+
90
+
91
+ @functools.lru_cache()
92
+ def maybe_script(fn: callable) -> callable:
93
+ """Apply torch.jit.script to function unless one is using TPU. TPU does not support torch.jit.script."""
94
+ using_tpu = bool(os.environ.get("TPU_NAME"))
95
+ # this is a reserved variable that must be set to TPU address (e.g. grpc://11.22.33.44:1337) for TPU to function
96
+ should_script = int(os.environ.get("AQ_USE_JIT", not using_tpu))
97
+ return torch.jit.script(fn) if should_script else fn
98
+
99
+
100
+ @maybe_script
101
+ def _dequantize_weight(
102
+ codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None
103
+ ) -> torch.Tensor:
104
+ """
105
+ Decode float weights from quantization codes. Differentiable.
106
+ :param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks]
107
+ :param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size]
108
+ :param scales: weight will be multiplied by this factor, must be broadcastble with [*dims, out_groups, num_in_groups, out_group_size, in_group_size]
109
+ :return: reconstructed weight tensor of shape [*dims, num_in_groups*group_size]
110
+ """
111
+ num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
112
+ num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
113
+ out_features = num_out_groups * out_group_size
114
+ in_features = num_in_groups * in_group_size
115
+ codebook_offsets = torch.arange(
116
+ 0, num_codebooks * codebook_size, codebook_size, device=codes.device
117
+ ) # shape: [num_codebooks]
118
+ reconstructed_weight_flat = F.embedding_bag(
119
+ codes.flatten(0, -2) + codebook_offsets,
120
+ codebooks.flatten(0, 1).flatten(-2, -1),
121
+ mode="sum",
122
+ ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size]
123
+
124
+ reconstructed_weight_groupwise = reconstructed_weight_flat.view(
125
+ list(codes.shape[:-3])
126
+ + [num_out_groups, num_in_groups, out_group_size, in_group_size]
127
+ )
128
+ if scales is not None:
129
+ reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
130
+ return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(
131
+ list(codes.shape[:-3]) + [out_features, in_features]
132
+ )
133
+
134
+
135
+ def forward_pass_quantized_linear(
136
+ input: torch.Tensor,
137
+ codes: torch.IntTensor,
138
+ codebooks: torch.Tensor,
139
+ scales: torch.Tensor,
140
+ bias: Optional[torch.Tensor],
141
+ ) -> torch.Tensor:
142
+ if input.is_cuda:
143
+ matmul_result = aqlm_gemm_stupid(input, codes, codebooks, scales)
144
+ if bias is not None:
145
+ matmul_result += bias
146
+ return matmul_result
147
+ else:
148
+ dequantized_weight = _dequantize_weight(
149
+ unpack_int_data(codes, codebooks.shape[0].bit_length() - 1),
150
+ codebooks,
151
+ scales,
152
+ )
153
+ return F.linear(input, dequantized_weight, bias)
154
+
155
+
156
+ @triton.autotune(
157
+ configs=[
158
+ triton.Config({"UNUSED": 1}, num_stages=num_stages, num_warps=num_warps)
159
+ for num_stages in (1, 2, 3, 4, 5)
160
+ for num_warps in (1, 2, 4, 8)
161
+ ],
162
+ key=[
163
+ "in_features",
164
+ "out_features",
165
+ "num_codebooks",
166
+ "codebook_size",
167
+ "out_group_size",
168
+ "in_group_size",
169
+ "num_input_groups",
170
+ "num_input_groups_next_power_of_2",
171
+ "compute_in_fp32",
172
+ ],
173
+ )
174
+ @triton.jit
175
+ def _aqlm_gemv_simple(
176
+ input_vec_ptr,
177
+ output_vec_ptr,
178
+ codes_i16_ptr,
179
+ codebooks_ptr,
180
+ scales_ptr,
181
+ in_features: tl.constexpr,
182
+ out_features: tl.constexpr,
183
+ num_codebooks: tl.constexpr,
184
+ codebook_size: tl.constexpr,
185
+ out_group_size: tl.constexpr,
186
+ in_group_size: tl.constexpr,
187
+ num_input_groups: tl.constexpr,
188
+ num_input_groups_next_power_of_2: tl.constexpr,
189
+ compute_in_fp32: tl.constexpr,
190
+ UNUSED: tl.constexpr,
191
+ ):
192
+ # variables ending with "_i" mean "for i-th output unit"
193
+ pid = tl.program_id(axis=0) # [0, 1, ... {out_features-1}]
194
+
195
+ # Stage 1: load input data
196
+ input_vec = tl.load(
197
+ input_vec_ptr
198
+ + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] * in_group_size
199
+ + tl.arange(0, in_group_size)[None, None, :],
200
+ mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None]
201
+ < num_input_groups,
202
+ )
203
+ # [in_features//in_group_size, 1, group_size]
204
+ # Note: we could simply load input_vec then reshape
205
+ # input_vec = tl.load(input_vec_ptr + tl.arange(0, in_features)) # [in_features]
206
+ # input_vec = tl.view(input_vec, [num_input_groups, 1, in_group_size])
207
+ # , but this does not work because tl.view may reorder elements arbitrarily; see its docstring
208
+
209
+ # Stage 2: load integer codes for the active row
210
+ # [in_features // in_group_size, num_codebooks]
211
+ codes_i_ptrs = (
212
+ codes_i16_ptr
213
+ + pid * num_input_groups * num_codebooks
214
+ + tl.arange(0, num_input_groups_next_power_of_2)[:, None] * num_codebooks
215
+ + tl.arange(0, num_codebooks)[None, :]
216
+ )
217
+ codes_i_mask_1d = tl.arange(0, num_input_groups_next_power_of_2) < num_input_groups
218
+
219
+ codes_i = tl.load(
220
+ codes_i_ptrs, mask=codes_i_mask_1d[:, None]
221
+ ) # [in_features//in_group_size, num_codebooks]
222
+ if codes_i.dtype == tl.int16:
223
+ codes_i = codes_i.to(tl.int32)
224
+ codes_i = (codes_i) + (
225
+ codes_i < 0
226
+ ) * codebook_size # aka 2 ** nbits_per_codebook
227
+ # ^-- (because codes are int16 tensors that contain uint data)
228
+
229
+ # The following alternative does not work:
230
+ # codes_i = codes_i.to(tl.int32) % codebook_size # aka 2 ** nbits_per_codebook
231
+ else:
232
+ codes_i = codes_i.to(tl.int32)
233
+
234
+ # shift codes_i so that codebooks after 0th point to correct indices in codebooks_ptr
235
+ codes_i += (
236
+ tl.arange(0, num_codebooks)[None, :] * codebook_size
237
+ ) # aka 2 ** nbits_per_codebook
238
+ # ^-- [in_group_size, num_codebooks]
239
+
240
+ # Stage 3: convert codes to pointers to every individual (activated) weight in codebooks
241
+ # [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
242
+ out_group_ix = tl.arange(0, out_group_size)[None, None, :, None]
243
+ in_group_ix = tl.arange(0, in_group_size)[None, None, None, :]
244
+ weight_i_ptrs = (
245
+ codebooks_ptr
246
+ + codes_i[:, :, None, None] * out_group_size * in_group_size
247
+ + out_group_ix * in_group_size
248
+ + in_group_ix
249
+ )
250
+
251
+ # Stage 4: reconstruct weights, multiply by inputs and write out
252
+ weights_i = tl.load(
253
+ weight_i_ptrs, mask=codes_i_mask_1d[:, None, None, None], other=0
254
+ )
255
+ if compute_in_fp32:
256
+ weights_i = weights_i.to(tl.float32)
257
+ input_vec = input_vec.to(tl.float32)
258
+ # ^-- [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
259
+ weights_i = tl.sum(weights_i, axis=1) # sum codebooks as per additive quantization
260
+ # ^-- [in_features // in_group_size, out_group_size, in_group_size]
261
+
262
+ if out_group_size == 1:
263
+ scale = tl.load(scales_ptr + pid).to(weights_i.dtype) # scalar
264
+ output_i = tl.sum(weights_i * input_vec) * scale
265
+ tl.store(output_vec_ptr + pid, output_i.to(input_vec.dtype))
266
+ else:
267
+ output_i = tl.sum(
268
+ tl.sum(weights_i * input_vec, axis=2), axis=0
269
+ ) # [out_group_size]
270
+ output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)
271
+ tl.store(
272
+ output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size),
273
+ output_i.to(input_vec.dtype),
274
+ )
275
+
276
+
277
+ def next_power_of_2(x):
278
+ return 1 if x == 0 else 2 ** (x - 1).bit_length()
279
+
280
+
281
+ def aqlm_gemv_simple(
282
+ input_vec: torch.Tensor,
283
+ codes_i16: torch.ShortTensor,
284
+ codebooks: torch.Tensor,
285
+ scales: torch.Tensor,
286
+ compute_in_fp32: bool = True,
287
+ ):
288
+
289
+ device, dtype = codebooks.device, codebooks.dtype
290
+ num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
291
+ in_features = input_vec.shape[1]
292
+ out_features = codes_i16.shape[0] * out_group_size
293
+ num_input_groups = codes_i16.shape[1]
294
+ assert input_vec.ndim == 2 and input_vec.shape[0] == 1, "do reshape; now!"
295
+ assert scales.shape == (out_features // out_group_size, 1, 1, 1)
296
+ assert in_features % in_group_size == 0
297
+ assert codebooks.shape[1] == 2**16
298
+
299
+ output_vec = torch.empty(1, out_features, device=device, dtype=dtype)
300
+ # 1D launch kernel where each block computes output unit
301
+ grid = lambda META: (out_features // out_group_size,)
302
+ _aqlm_gemv_simple[grid](
303
+ input_vec,
304
+ output_vec,
305
+ codes_i16,
306
+ codebooks,
307
+ scales,
308
+ in_features,
309
+ out_features,
310
+ num_codebooks,
311
+ codebook_size,
312
+ out_group_size,
313
+ in_group_size,
314
+ num_input_groups,
315
+ next_power_of_2(num_input_groups),
316
+ compute_in_fp32,
317
+ )
318
+
319
+ return output_vec
320
+
321
+
322
+ def aqlm_gemm_stupid(
323
+ input: torch.Tensor,
324
+ codes_i16: torch.ShortTensor,
325
+ codebooks: torch.Tensor,
326
+ scales: torch.Tensor,
327
+ compute_in_fp32: bool = True,
328
+ ):
329
+ original_shape = input.shape
330
+ input = input.reshape(-1, original_shape[-1])
331
+ return torch.cat(
332
+ [
333
+ aqlm_gemv_simple(
334
+ input_vec.unsqueeze(0), codes_i16, codebooks, scales, compute_in_fp32
335
+ )
336
+ for input_vec in input
337
+ ]
338
+ ).reshape(original_shape[:-1] + (-1,))
modeling_llama.py CHANGED
@@ -27,27 +27,23 @@ import torch.utils.checkpoint
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
  from transformers.activations import ACT2FN
30
- from transformers.modeling_outputs import (
31
- BaseModelOutputWithPast,
32
- CausalLMOutputWithPast,
33
- SequenceClassifierOutputWithPast,
34
- )
35
  from transformers.modeling_utils import PreTrainedModel
36
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
37
- from transformers.utils import (
38
- add_start_docstrings,
39
- add_start_docstrings_to_model_forward,
40
- is_flash_attn_available,
41
- logging,
42
- replace_return_docstrings,
43
- )
44
 
45
  from .configuration_llama import LlamaConfig
46
- from src.inference import FinalizedQuantizedLinear
47
 
48
  if is_flash_attn_available():
49
  from flash_attn import flash_attn_func, flash_attn_varlen_func
50
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
 
51
 
52
 
53
  logger = logging.get_logger(__name__)
@@ -59,7 +55,9 @@ def _get_unpad_data(padding_mask):
59
  seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
60
  indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
61
  max_seqlen_in_batch = seqlens_in_batch.max().item()
62
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
 
 
63
  return (
64
  indices,
65
  cu_seqlens,
@@ -69,7 +67,10 @@ def _get_unpad_data(padding_mask):
69
 
70
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
71
  def _make_causal_mask(
72
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
 
 
 
73
  ):
74
  """
75
  Make causal mask used for bi-directional self-attention.
@@ -81,8 +82,18 @@ def _make_causal_mask(
81
  mask = mask.to(dtype)
82
 
83
  if past_key_values_length > 0:
84
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
85
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
 
 
 
 
 
 
 
 
 
86
 
87
 
88
  # Copied from transformers.models.bart.modeling_bart._expand_mask
@@ -97,7 +108,9 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
97
 
98
  inverted_mask = 1.0 - expanded_mask
99
 
100
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
 
 
101
 
102
 
103
  class LlamaRMSNorm(nn.Module):
@@ -127,23 +140,33 @@ class LlamaRotaryEmbedding(nn.Module):
127
  self.dim = dim
128
  self.max_position_embeddings = max_position_embeddings
129
  self.base = base
130
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
 
 
131
  self.register_buffer("inv_freq", inv_freq, persistent=False)
132
 
133
  # Build here to make `torch.jit.trace` work.
134
  self._set_cos_sin_cache(
135
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
 
 
136
  )
137
 
138
  def _set_cos_sin_cache(self, seq_len, device, dtype):
139
  self.max_seq_len_cached = seq_len
140
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
 
141
 
142
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
143
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
144
  emb = torch.cat((freqs, freqs), dim=-1)
145
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
146
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
 
 
 
 
147
 
148
  def forward(self, x, seq_len=None):
149
  # x: [bs, num_attention_heads, seq_len, head_size]
@@ -159,26 +182,46 @@ class LlamaRotaryEmbedding(nn.Module):
159
  class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
160
  """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
161
 
162
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
 
 
 
 
 
 
 
163
  self.scaling_factor = scaling_factor
164
  super().__init__(dim, max_position_embeddings, base, device)
165
 
166
  def _set_cos_sin_cache(self, seq_len, device, dtype):
167
  self.max_seq_len_cached = seq_len
168
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
 
169
  t = t / self.scaling_factor
170
 
171
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
172
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
173
  emb = torch.cat((freqs, freqs), dim=-1)
174
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
175
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
 
 
 
 
176
 
177
 
178
  class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
179
  """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
180
 
181
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
 
 
 
 
 
 
 
182
  self.scaling_factor = scaling_factor
183
  super().__init__(dim, max_position_embeddings, base, device)
184
 
@@ -187,18 +230,27 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
187
 
188
  if seq_len > self.max_position_embeddings:
189
  base = self.base * (
190
- (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
 
191
  ) ** (self.dim / (self.dim - 2))
192
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
 
 
193
  self.register_buffer("inv_freq", inv_freq, persistent=False)
194
 
195
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
 
196
 
197
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
198
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
199
  emb = torch.cat((freqs, freqs), dim=-1)
200
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
201
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
 
 
 
 
202
 
203
 
204
  def rotate_half(x):
@@ -225,9 +277,15 @@ class LlamaMLP(nn.Module):
225
  self.config = config
226
  self.hidden_size = config.hidden_size
227
  self.intermediate_size = config.intermediate_size
228
- self.gate_proj = FinalizedQuantizedLinear(self.hidden_size, self.intermediate_size, bias=False, **config.aqlm)
229
- self.up_proj = FinalizedQuantizedLinear(self.hidden_size, self.intermediate_size, bias=False, **config.aqlm)
230
- self.down_proj = FinalizedQuantizedLinear(self.intermediate_size, self.hidden_size, bias=False, **config.aqlm)
 
 
 
 
 
 
231
  self.act_fn = ACT2FN[config.hidden_act]
232
 
233
  def forward(self, x):
@@ -237,12 +295,25 @@ class LlamaMLP(nn.Module):
237
  up_proj_slices = self.up_proj.weight.split(slice, dim=0)
238
  down_proj_slices = self.down_proj.weight.split(slice, dim=1)
239
 
240
- gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
241
- up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
  intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
244
  down_proj = [
245
- F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
 
246
  ]
247
  down_proj = sum(down_proj)
248
  else:
@@ -259,7 +330,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
259
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
260
  if n_rep == 1:
261
  return hidden_states
262
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
 
 
263
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
264
 
265
 
@@ -283,16 +356,28 @@ class LlamaAttention(nn.Module):
283
  f" and `num_heads`: {self.num_heads})."
284
  )
285
  self.q_proj = FinalizedQuantizedLinear(
286
- self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias, **config.aqlm
 
 
 
287
  )
288
  self.k_proj = FinalizedQuantizedLinear(
289
- self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, **config.aqlm
 
 
 
290
  )
291
  self.v_proj = FinalizedQuantizedLinear(
292
- self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, **config.aqlm
 
 
 
293
  )
294
  self.o_proj = FinalizedQuantizedLinear(
295
- self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias, **config.aqlm
 
 
 
296
  )
297
  self._init_rope()
298
 
@@ -324,7 +409,11 @@ class LlamaAttention(nn.Module):
324
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
325
 
326
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
327
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
 
 
 
328
 
329
  def forward(
330
  self,
@@ -339,20 +428,31 @@ class LlamaAttention(nn.Module):
339
  bsz, q_len, _ = hidden_states.size()
340
 
341
  if self.config.pretraining_tp > 1:
342
- key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
 
 
343
  query_slices = self.q_proj.weight.split(
344
  (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
345
  )
346
  key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
347
  value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
348
 
349
- query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
 
 
 
350
  query_states = torch.cat(query_states, dim=-1)
351
 
352
- key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
 
 
 
353
  key_states = torch.cat(key_states, dim=-1)
354
 
355
- value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
 
 
 
356
  value_states = torch.cat(value_states, dim=-1)
357
 
358
  else:
@@ -360,15 +460,23 @@ class LlamaAttention(nn.Module):
360
  key_states = self.k_proj(hidden_states)
361
  value_states = self.v_proj(hidden_states)
362
 
363
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
364
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
365
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
366
 
367
  kv_seq_len = key_states.shape[-2]
368
  if past_key_value is not None:
369
  kv_seq_len += past_key_value[0].shape[-2]
370
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
371
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
 
372
 
373
  if past_key_value is not None:
374
  # reuse k, v, self_attention
@@ -380,7 +488,9 @@ class LlamaAttention(nn.Module):
380
  key_states = repeat_kv(key_states, self.num_key_value_groups)
381
  value_states = repeat_kv(value_states, self.num_key_value_groups)
382
 
383
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
 
384
 
385
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
386
  raise ValueError(
@@ -396,7 +506,9 @@ class LlamaAttention(nn.Module):
396
  attn_weights = attn_weights + attention_mask
397
 
398
  # upcast attention to fp32
399
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
 
 
400
  attn_output = torch.matmul(attn_weights, value_states)
401
 
402
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -410,9 +522,18 @@ class LlamaAttention(nn.Module):
410
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
411
 
412
  if self.config.pretraining_tp > 1:
413
- attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
414
- o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
415
- attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
 
 
 
 
 
 
 
 
 
416
  else:
417
  attn_output = self.o_proj(attn_output)
418
 
@@ -451,9 +572,15 @@ class LlamaFlashAttention2(LlamaAttention):
451
  # Flash attention requires the input to have the shape
452
  # batch_size x seq_length x head_dime x hidden_dim
453
  # therefore we just need to keep the original shape
454
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
455
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
456
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
457
 
458
  kv_seq_len = key_states.shape[-2]
459
  if past_key_value is not None:
@@ -461,7 +588,9 @@ class LlamaFlashAttention2(LlamaAttention):
461
 
462
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
463
 
464
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
 
465
 
466
  if past_key_value is not None:
467
  # reuse k, v, self_attention
@@ -497,7 +626,12 @@ class LlamaFlashAttention2(LlamaAttention):
497
  value_states = value_states.to(torch.float16)
498
 
499
  attn_output = self._flash_attention_forward(
500
- query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
 
 
 
 
 
501
  )
502
 
503
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -509,7 +643,14 @@ class LlamaFlashAttention2(LlamaAttention):
509
  return attn_output, attn_weights, past_key_value
510
 
511
  def _flash_attention_forward(
512
- self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None
 
 
 
 
 
 
 
513
  ):
514
  """
515
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -533,7 +674,14 @@ class LlamaFlashAttention2(LlamaAttention):
533
  # Contains at least one padding token in the sequence
534
  if padding_mask is not None:
535
  batch_size = query_states.shape[0]
536
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
 
 
 
 
 
 
 
537
  query_states, key_states, value_states, padding_mask, query_length
538
  )
539
 
@@ -553,27 +701,39 @@ class LlamaFlashAttention2(LlamaAttention):
553
  causal=True,
554
  )
555
 
556
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
 
 
557
  else:
558
  attn_output = flash_attn_func(
559
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
 
 
 
 
 
560
  )
561
 
562
  return attn_output
563
 
564
- def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
 
 
565
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
566
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
567
 
568
  key_layer = index_first_axis(
569
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
 
570
  )
571
  value_layer = index_first_axis(
572
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
 
573
  )
574
  if query_length == kv_seq_len:
575
  query_layer = index_first_axis(
576
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
 
577
  )
578
  cu_seqlens_q = cu_seqlens_k
579
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
@@ -588,7 +748,9 @@ class LlamaFlashAttention2(LlamaAttention):
588
  else:
589
  # The -q_len: slice assumes left padding.
590
  padding_mask = padding_mask[:, -query_length:]
591
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask)
 
 
592
 
593
  return (
594
  query_layer,
@@ -611,7 +773,9 @@ class LlamaDecoderLayer(nn.Module):
611
  )
612
  self.mlp = LlamaMLP(config)
613
  self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
614
- self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
615
 
616
  def forward(
617
  self,
@@ -622,7 +786,9 @@ class LlamaDecoderLayer(nn.Module):
622
  output_attentions: Optional[bool] = False,
623
  use_cache: Optional[bool] = False,
624
  padding_mask: Optional[torch.LongTensor] = None,
625
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
 
626
  """
627
  Args:
628
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -796,8 +962,12 @@ class LlamaModel(LlamaPreTrainedModel):
796
  self.padding_idx = config.pad_token_id
797
  self.vocab_size = config.vocab_size
798
 
799
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
800
- self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
 
 
 
 
801
  self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
802
 
803
  self.gradient_checkpointing = False
@@ -811,7 +981,9 @@ class LlamaModel(LlamaPreTrainedModel):
811
  self.embed_tokens = value
812
 
813
  # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
814
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
 
 
815
  # create causal mask
816
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
817
  combined_attention_mask = None
@@ -825,11 +997,13 @@ class LlamaModel(LlamaPreTrainedModel):
825
 
826
  if attention_mask is not None:
827
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
828
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
829
- inputs_embeds.device
830
- )
831
  combined_attention_mask = (
832
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
 
 
833
  )
834
 
835
  return combined_attention_mask
@@ -847,17 +1021,27 @@ class LlamaModel(LlamaPreTrainedModel):
847
  output_hidden_states: Optional[bool] = None,
848
  return_dict: Optional[bool] = None,
849
  ) -> Union[Tuple, BaseModelOutputWithPast]:
850
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
851
  output_hidden_states = (
852
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
853
  )
854
  use_cache = use_cache if use_cache is not None else self.config.use_cache
855
 
856
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
857
 
858
  # retrieve input_ids and inputs_embeds
859
  if input_ids is not None and inputs_embeds is not None:
860
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
 
 
861
  elif input_ids is not None:
862
  batch_size, seq_length = input_ids.shape
863
  elif inputs_embeds is not None:
@@ -875,7 +1059,10 @@ class LlamaModel(LlamaPreTrainedModel):
875
  if position_ids is None:
876
  device = input_ids.device if input_ids is not None else inputs_embeds.device
877
  position_ids = torch.arange(
878
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
 
 
 
879
  )
880
  position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
881
  else:
@@ -886,7 +1073,9 @@ class LlamaModel(LlamaPreTrainedModel):
886
  # embed positions
887
  if attention_mask is None:
888
  attention_mask = torch.ones(
889
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
 
 
890
  )
891
  padding_mask = None
892
  else:
@@ -896,7 +1085,10 @@ class LlamaModel(LlamaPreTrainedModel):
896
  padding_mask = None
897
 
898
  attention_mask = self._prepare_decoder_attention_mask(
899
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
 
 
 
900
  )
901
 
902
  hidden_states = inputs_embeds
@@ -917,19 +1109,29 @@ class LlamaModel(LlamaPreTrainedModel):
917
  if output_hidden_states:
918
  all_hidden_states += (hidden_states,)
919
 
920
- past_key_value = past_key_values[idx] if past_key_values is not None else None
 
 
921
 
922
  if self.gradient_checkpointing and self.training:
923
 
924
  def create_custom_forward(module):
925
  def custom_forward(*inputs):
926
  # None for past_key_value
927
- return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
 
 
 
 
 
928
 
929
  return custom_forward
930
 
931
  layer_outputs = torch.utils.checkpoint.checkpoint(
932
- create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
 
 
 
933
  )
934
  else:
935
  layer_outputs = decoder_layer(
@@ -958,7 +1160,11 @@ class LlamaModel(LlamaPreTrainedModel):
958
 
959
  next_cache = next_decoder_cache if use_cache else None
960
  if not return_dict:
961
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
 
 
 
962
  return BaseModelOutputWithPast(
963
  last_hidden_state=hidden_states,
964
  past_key_values=next_cache,
@@ -998,7 +1204,9 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
998
  return self.model
999
 
1000
  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1001
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
 
1002
  def forward(
1003
  self,
1004
  input_ids: torch.LongTensor = None,
@@ -1038,11 +1246,19 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1038
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1039
  ```"""
1040
 
1041
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
1042
  output_hidden_states = (
1043
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
1044
  )
1045
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1046
 
1047
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1048
  outputs = self.model(
@@ -1059,8 +1275,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1059
 
1060
  hidden_states = outputs[0]
1061
  if self.config.pretraining_tp > 1:
1062
- lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1063
- logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
 
 
 
 
 
1064
  logits = torch.cat(logits, dim=-1)
1065
  else:
1066
  logits = self.lm_head(hidden_states)
@@ -1092,7 +1313,12 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1092
  )
1093
 
1094
  def prepare_inputs_for_generation(
1095
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
 
 
 
 
 
1096
  ):
1097
  if past_key_values:
1098
  input_ids = input_ids[:, -1:]
@@ -1126,7 +1352,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1126
  reordered_past = ()
1127
  for layer_past in past_key_values:
1128
  reordered_past += (
1129
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
 
 
 
1130
  )
1131
  return reordered_past
1132
 
@@ -1182,7 +1411,9 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1182
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1183
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1184
  """
1185
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1186
 
1187
  transformer_outputs = self.model(
1188
  input_ids,
@@ -1204,18 +1435,22 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1204
  batch_size = inputs_embeds.shape[0]
1205
 
1206
  if self.config.pad_token_id is None and batch_size != 1:
1207
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
 
 
1208
  if self.config.pad_token_id is None:
1209
  sequence_lengths = -1
1210
  else:
1211
  if input_ids is not None:
1212
- sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
1213
- logits.device
1214
- )
1215
  else:
1216
  sequence_lengths = -1
1217
 
1218
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
 
 
1219
 
1220
  loss = None
1221
  if labels is not None:
@@ -1223,7 +1458,9 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1223
  if self.config.problem_type is None:
1224
  if self.num_labels == 1:
1225
  self.config.problem_type = "regression"
1226
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
 
 
1227
  self.config.problem_type = "single_label_classification"
1228
  else:
1229
  self.config.problem_type = "multi_label_classification"
@@ -1236,7 +1473,9 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1236
  loss = loss_fct(pooled_logits, labels)
1237
  elif self.config.problem_type == "single_label_classification":
1238
  loss_fct = CrossEntropyLoss()
1239
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
 
 
1240
  elif self.config.problem_type == "multi_label_classification":
1241
  loss_fct = BCEWithLogitsLoss()
1242
  loss = loss_fct(pooled_logits, labels)
 
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
  from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
31
+ CausalLMOutputWithPast,
32
+ SequenceClassifierOutputWithPast)
 
 
33
  from transformers.modeling_utils import PreTrainedModel
34
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
35
+ from transformers.utils import (add_start_docstrings,
36
+ add_start_docstrings_to_model_forward,
37
+ is_flash_attn_available, logging,
38
+ replace_return_docstrings)
 
 
 
39
 
40
  from .configuration_llama import LlamaConfig
41
+ from .inference import FinalizedQuantizedLinear
42
 
43
  if is_flash_attn_available():
44
  from flash_attn import flash_attn_func, flash_attn_varlen_func
45
+ from flash_attn.bert_padding import (index_first_axis, pad_input, # noqa
46
+ unpad_input)
47
 
48
 
49
  logger = logging.get_logger(__name__)
 
55
  seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
56
  indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
57
  max_seqlen_in_batch = seqlens_in_batch.max().item()
58
+ cu_seqlens = F.pad(
59
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
60
+ )
61
  return (
62
  indices,
63
  cu_seqlens,
 
67
 
68
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
69
  def _make_causal_mask(
70
+ input_ids_shape: torch.Size,
71
+ dtype: torch.dtype,
72
+ device: torch.device,
73
+ past_key_values_length: int = 0,
74
  ):
75
  """
76
  Make causal mask used for bi-directional self-attention.
 
82
  mask = mask.to(dtype)
83
 
84
  if past_key_values_length > 0:
85
+ mask = torch.cat(
86
+ [
87
+ torch.zeros(
88
+ tgt_len, past_key_values_length, dtype=dtype, device=device
89
+ ),
90
+ mask,
91
+ ],
92
+ dim=-1,
93
+ )
94
+ return mask[None, None, :, :].expand(
95
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
96
+ )
97
 
98
 
99
  # Copied from transformers.models.bart.modeling_bart._expand_mask
 
108
 
109
  inverted_mask = 1.0 - expanded_mask
110
 
111
+ return inverted_mask.masked_fill(
112
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
113
+ )
114
 
115
 
116
  class LlamaRMSNorm(nn.Module):
 
140
  self.dim = dim
141
  self.max_position_embeddings = max_position_embeddings
142
  self.base = base
143
+ inv_freq = 1.0 / (
144
+ self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
145
+ )
146
  self.register_buffer("inv_freq", inv_freq, persistent=False)
147
 
148
  # Build here to make `torch.jit.trace` work.
149
  self._set_cos_sin_cache(
150
+ seq_len=max_position_embeddings,
151
+ device=self.inv_freq.device,
152
+ dtype=torch.get_default_dtype(),
153
  )
154
 
155
  def _set_cos_sin_cache(self, seq_len, device, dtype):
156
  self.max_seq_len_cached = seq_len
157
+ t = torch.arange(
158
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
159
+ )
160
 
161
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
162
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
163
  emb = torch.cat((freqs, freqs), dim=-1)
164
+ self.register_buffer(
165
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
166
+ )
167
+ self.register_buffer(
168
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
169
+ )
170
 
171
  def forward(self, x, seq_len=None):
172
  # x: [bs, num_attention_heads, seq_len, head_size]
 
182
  class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
183
  """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
184
 
185
+ def __init__(
186
+ self,
187
+ dim,
188
+ max_position_embeddings=2048,
189
+ base=10000,
190
+ device=None,
191
+ scaling_factor=1.0,
192
+ ):
193
  self.scaling_factor = scaling_factor
194
  super().__init__(dim, max_position_embeddings, base, device)
195
 
196
  def _set_cos_sin_cache(self, seq_len, device, dtype):
197
  self.max_seq_len_cached = seq_len
198
+ t = torch.arange(
199
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
200
+ )
201
  t = t / self.scaling_factor
202
 
203
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
204
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
205
  emb = torch.cat((freqs, freqs), dim=-1)
206
+ self.register_buffer(
207
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
208
+ )
209
+ self.register_buffer(
210
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
211
+ )
212
 
213
 
214
  class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
215
  """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
216
 
217
+ def __init__(
218
+ self,
219
+ dim,
220
+ max_position_embeddings=2048,
221
+ base=10000,
222
+ device=None,
223
+ scaling_factor=1.0,
224
+ ):
225
  self.scaling_factor = scaling_factor
226
  super().__init__(dim, max_position_embeddings, base, device)
227
 
 
230
 
231
  if seq_len > self.max_position_embeddings:
232
  base = self.base * (
233
+ (self.scaling_factor * seq_len / self.max_position_embeddings)
234
+ - (self.scaling_factor - 1)
235
  ) ** (self.dim / (self.dim - 2))
236
+ inv_freq = 1.0 / (
237
+ base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
238
+ )
239
  self.register_buffer("inv_freq", inv_freq, persistent=False)
240
 
241
+ t = torch.arange(
242
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
243
+ )
244
 
245
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
246
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
247
  emb = torch.cat((freqs, freqs), dim=-1)
248
+ self.register_buffer(
249
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
250
+ )
251
+ self.register_buffer(
252
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
253
+ )
254
 
255
 
256
  def rotate_half(x):
 
277
  self.config = config
278
  self.hidden_size = config.hidden_size
279
  self.intermediate_size = config.intermediate_size
280
+ self.gate_proj = FinalizedQuantizedLinear(
281
+ self.hidden_size, self.intermediate_size, bias=False, **config.aqlm
282
+ )
283
+ self.up_proj = FinalizedQuantizedLinear(
284
+ self.hidden_size, self.intermediate_size, bias=False, **config.aqlm
285
+ )
286
+ self.down_proj = FinalizedQuantizedLinear(
287
+ self.intermediate_size, self.hidden_size, bias=False, **config.aqlm
288
+ )
289
  self.act_fn = ACT2FN[config.hidden_act]
290
 
291
  def forward(self, x):
 
295
  up_proj_slices = self.up_proj.weight.split(slice, dim=0)
296
  down_proj_slices = self.down_proj.weight.split(slice, dim=1)
297
 
298
+ gate_proj = torch.cat(
299
+ [
300
+ F.linear(x, gate_proj_slices[i])
301
+ for i in range(self.config.pretraining_tp)
302
+ ],
303
+ dim=-1,
304
+ )
305
+ up_proj = torch.cat(
306
+ [
307
+ F.linear(x, up_proj_slices[i])
308
+ for i in range(self.config.pretraining_tp)
309
+ ],
310
+ dim=-1,
311
+ )
312
 
313
  intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
314
  down_proj = [
315
+ F.linear(intermediate_states[i], down_proj_slices[i])
316
+ for i in range(self.config.pretraining_tp)
317
  ]
318
  down_proj = sum(down_proj)
319
  else:
 
330
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
331
  if n_rep == 1:
332
  return hidden_states
333
+ hidden_states = hidden_states[:, :, None, :, :].expand(
334
+ batch, num_key_value_heads, n_rep, slen, head_dim
335
+ )
336
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
337
 
338
 
 
356
  f" and `num_heads`: {self.num_heads})."
357
  )
358
  self.q_proj = FinalizedQuantizedLinear(
359
+ self.hidden_size,
360
+ self.num_heads * self.head_dim,
361
+ bias=config.attention_bias,
362
+ **config.aqlm,
363
  )
364
  self.k_proj = FinalizedQuantizedLinear(
365
+ self.hidden_size,
366
+ self.num_key_value_heads * self.head_dim,
367
+ bias=config.attention_bias,
368
+ **config.aqlm,
369
  )
370
  self.v_proj = FinalizedQuantizedLinear(
371
+ self.hidden_size,
372
+ self.num_key_value_heads * self.head_dim,
373
+ bias=config.attention_bias,
374
+ **config.aqlm,
375
  )
376
  self.o_proj = FinalizedQuantizedLinear(
377
+ self.num_heads * self.head_dim,
378
+ self.hidden_size,
379
+ bias=config.attention_bias,
380
+ **config.aqlm,
381
  )
382
  self._init_rope()
383
 
 
409
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
410
 
411
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
412
+ return (
413
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
414
+ .transpose(1, 2)
415
+ .contiguous()
416
+ )
417
 
418
  def forward(
419
  self,
 
428
  bsz, q_len, _ = hidden_states.size()
429
 
430
  if self.config.pretraining_tp > 1:
431
+ key_value_slicing = (
432
+ self.num_key_value_heads * self.head_dim
433
+ ) // self.config.pretraining_tp
434
  query_slices = self.q_proj.weight.split(
435
  (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
436
  )
437
  key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
438
  value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
439
 
440
+ query_states = [
441
+ F.linear(hidden_states, query_slices[i])
442
+ for i in range(self.config.pretraining_tp)
443
+ ]
444
  query_states = torch.cat(query_states, dim=-1)
445
 
446
+ key_states = [
447
+ F.linear(hidden_states, key_slices[i])
448
+ for i in range(self.config.pretraining_tp)
449
+ ]
450
  key_states = torch.cat(key_states, dim=-1)
451
 
452
+ value_states = [
453
+ F.linear(hidden_states, value_slices[i])
454
+ for i in range(self.config.pretraining_tp)
455
+ ]
456
  value_states = torch.cat(value_states, dim=-1)
457
 
458
  else:
 
460
  key_states = self.k_proj(hidden_states)
461
  value_states = self.v_proj(hidden_states)
462
 
463
+ query_states = query_states.view(
464
+ bsz, q_len, self.num_heads, self.head_dim
465
+ ).transpose(1, 2)
466
+ key_states = key_states.view(
467
+ bsz, q_len, self.num_key_value_heads, self.head_dim
468
+ ).transpose(1, 2)
469
+ value_states = value_states.view(
470
+ bsz, q_len, self.num_key_value_heads, self.head_dim
471
+ ).transpose(1, 2)
472
 
473
  kv_seq_len = key_states.shape[-2]
474
  if past_key_value is not None:
475
  kv_seq_len += past_key_value[0].shape[-2]
476
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
477
+ query_states, key_states = apply_rotary_pos_emb(
478
+ query_states, key_states, cos, sin, position_ids
479
+ )
480
 
481
  if past_key_value is not None:
482
  # reuse k, v, self_attention
 
488
  key_states = repeat_kv(key_states, self.num_key_value_groups)
489
  value_states = repeat_kv(value_states, self.num_key_value_groups)
490
 
491
+ attn_weights = torch.matmul(
492
+ query_states, key_states.transpose(2, 3)
493
+ ) / math.sqrt(self.head_dim)
494
 
495
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
496
  raise ValueError(
 
506
  attn_weights = attn_weights + attention_mask
507
 
508
  # upcast attention to fp32
509
+ attn_weights = nn.functional.softmax(
510
+ attn_weights, dim=-1, dtype=torch.float32
511
+ ).to(query_states.dtype)
512
  attn_output = torch.matmul(attn_weights, value_states)
513
 
514
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 
522
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
523
 
524
  if self.config.pretraining_tp > 1:
525
+ attn_output = attn_output.split(
526
+ self.hidden_size // self.config.pretraining_tp, dim=2
527
+ )
528
+ o_proj_slices = self.o_proj.weight.split(
529
+ self.hidden_size // self.config.pretraining_tp, dim=1
530
+ )
531
+ attn_output = sum(
532
+ [
533
+ F.linear(attn_output[i], o_proj_slices[i])
534
+ for i in range(self.config.pretraining_tp)
535
+ ]
536
+ )
537
  else:
538
  attn_output = self.o_proj(attn_output)
539
 
 
572
  # Flash attention requires the input to have the shape
573
  # batch_size x seq_length x head_dime x hidden_dim
574
  # therefore we just need to keep the original shape
575
+ query_states = query_states.view(
576
+ bsz, q_len, self.num_heads, self.head_dim
577
+ ).transpose(1, 2)
578
+ key_states = key_states.view(
579
+ bsz, q_len, self.num_key_value_heads, self.head_dim
580
+ ).transpose(1, 2)
581
+ value_states = value_states.view(
582
+ bsz, q_len, self.num_key_value_heads, self.head_dim
583
+ ).transpose(1, 2)
584
 
585
  kv_seq_len = key_states.shape[-2]
586
  if past_key_value is not None:
 
588
 
589
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
590
 
591
+ query_states, key_states = apply_rotary_pos_emb(
592
+ query_states, key_states, cos, sin, position_ids
593
+ )
594
 
595
  if past_key_value is not None:
596
  # reuse k, v, self_attention
 
626
  value_states = value_states.to(torch.float16)
627
 
628
  attn_output = self._flash_attention_forward(
629
+ query_states,
630
+ key_states,
631
+ value_states,
632
+ padding_mask,
633
+ q_len,
634
+ dropout=dropout_rate,
635
  )
636
 
637
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
 
643
  return attn_output, attn_weights, past_key_value
644
 
645
  def _flash_attention_forward(
646
+ self,
647
+ query_states,
648
+ key_states,
649
+ value_states,
650
+ padding_mask,
651
+ query_length,
652
+ dropout=0.0,
653
+ softmax_scale=None,
654
  ):
655
  """
656
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
 
674
  # Contains at least one padding token in the sequence
675
  if padding_mask is not None:
676
  batch_size = query_states.shape[0]
677
+ (
678
+ query_states,
679
+ key_states,
680
+ value_states,
681
+ indices_q,
682
+ cu_seq_lens,
683
+ max_seq_lens,
684
+ ) = self._upad_input(
685
  query_states, key_states, value_states, padding_mask, query_length
686
  )
687
 
 
701
  causal=True,
702
  )
703
 
704
+ attn_output = pad_input(
705
+ attn_output_unpad, indices_q, batch_size, query_length
706
+ )
707
  else:
708
  attn_output = flash_attn_func(
709
+ query_states,
710
+ key_states,
711
+ value_states,
712
+ dropout,
713
+ softmax_scale=softmax_scale,
714
+ causal=True,
715
  )
716
 
717
  return attn_output
718
 
719
+ def _upad_input(
720
+ self, query_layer, key_layer, value_layer, padding_mask, query_length
721
+ ):
722
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
723
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
724
 
725
  key_layer = index_first_axis(
726
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
727
+ indices_k,
728
  )
729
  value_layer = index_first_axis(
730
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
731
+ indices_k,
732
  )
733
  if query_length == kv_seq_len:
734
  query_layer = index_first_axis(
735
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
736
+ indices_k,
737
  )
738
  cu_seqlens_q = cu_seqlens_k
739
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
 
748
  else:
749
  # The -q_len: slice assumes left padding.
750
  padding_mask = padding_mask[:, -query_length:]
751
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
752
+ query_layer, padding_mask
753
+ )
754
 
755
  return (
756
  query_layer,
 
773
  )
774
  self.mlp = LlamaMLP(config)
775
  self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
776
+ self.post_attention_layernorm = LlamaRMSNorm(
777
+ config.hidden_size, eps=config.rms_norm_eps
778
+ )
779
 
780
  def forward(
781
  self,
 
786
  output_attentions: Optional[bool] = False,
787
  use_cache: Optional[bool] = False,
788
  padding_mask: Optional[torch.LongTensor] = None,
789
+ ) -> Tuple[
790
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
791
+ ]:
792
  """
793
  Args:
794
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
 
962
  self.padding_idx = config.pad_token_id
963
  self.vocab_size = config.vocab_size
964
 
965
+ self.embed_tokens = nn.Embedding(
966
+ config.vocab_size, config.hidden_size, self.padding_idx
967
+ )
968
+ self.layers = nn.ModuleList(
969
+ [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
970
+ )
971
  self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
972
 
973
  self.gradient_checkpointing = False
 
981
  self.embed_tokens = value
982
 
983
  # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
984
+ def _prepare_decoder_attention_mask(
985
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
986
+ ):
987
  # create causal mask
988
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
989
  combined_attention_mask = None
 
997
 
998
  if attention_mask is not None:
999
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1000
+ expanded_attn_mask = _expand_mask(
1001
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1002
+ ).to(inputs_embeds.device)
1003
  combined_attention_mask = (
1004
+ expanded_attn_mask
1005
+ if combined_attention_mask is None
1006
+ else expanded_attn_mask + combined_attention_mask
1007
  )
1008
 
1009
  return combined_attention_mask
 
1021
  output_hidden_states: Optional[bool] = None,
1022
  return_dict: Optional[bool] = None,
1023
  ) -> Union[Tuple, BaseModelOutputWithPast]:
1024
+ output_attentions = (
1025
+ output_attentions
1026
+ if output_attentions is not None
1027
+ else self.config.output_attentions
1028
+ )
1029
  output_hidden_states = (
1030
+ output_hidden_states
1031
+ if output_hidden_states is not None
1032
+ else self.config.output_hidden_states
1033
  )
1034
  use_cache = use_cache if use_cache is not None else self.config.use_cache
1035
 
1036
+ return_dict = (
1037
+ return_dict if return_dict is not None else self.config.use_return_dict
1038
+ )
1039
 
1040
  # retrieve input_ids and inputs_embeds
1041
  if input_ids is not None and inputs_embeds is not None:
1042
+ raise ValueError(
1043
+ "You cannot specify both input_ids and inputs_embeds at the same time"
1044
+ )
1045
  elif input_ids is not None:
1046
  batch_size, seq_length = input_ids.shape
1047
  elif inputs_embeds is not None:
 
1059
  if position_ids is None:
1060
  device = input_ids.device if input_ids is not None else inputs_embeds.device
1061
  position_ids = torch.arange(
1062
+ past_key_values_length,
1063
+ seq_length + past_key_values_length,
1064
+ dtype=torch.long,
1065
+ device=device,
1066
  )
1067
  position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1068
  else:
 
1073
  # embed positions
1074
  if attention_mask is None:
1075
  attention_mask = torch.ones(
1076
+ (batch_size, seq_length_with_past),
1077
+ dtype=torch.bool,
1078
+ device=inputs_embeds.device,
1079
  )
1080
  padding_mask = None
1081
  else:
 
1085
  padding_mask = None
1086
 
1087
  attention_mask = self._prepare_decoder_attention_mask(
1088
+ attention_mask,
1089
+ (batch_size, seq_length),
1090
+ inputs_embeds,
1091
+ past_key_values_length,
1092
  )
1093
 
1094
  hidden_states = inputs_embeds
 
1109
  if output_hidden_states:
1110
  all_hidden_states += (hidden_states,)
1111
 
1112
+ past_key_value = (
1113
+ past_key_values[idx] if past_key_values is not None else None
1114
+ )
1115
 
1116
  if self.gradient_checkpointing and self.training:
1117
 
1118
  def create_custom_forward(module):
1119
  def custom_forward(*inputs):
1120
  # None for past_key_value
1121
+ return module(
1122
+ *inputs,
1123
+ past_key_value,
1124
+ output_attentions,
1125
+ padding_mask=padding_mask,
1126
+ )
1127
 
1128
  return custom_forward
1129
 
1130
  layer_outputs = torch.utils.checkpoint.checkpoint(
1131
+ create_custom_forward(decoder_layer),
1132
+ hidden_states,
1133
+ attention_mask,
1134
+ position_ids,
1135
  )
1136
  else:
1137
  layer_outputs = decoder_layer(
 
1160
 
1161
  next_cache = next_decoder_cache if use_cache else None
1162
  if not return_dict:
1163
+ return tuple(
1164
+ v
1165
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1166
+ if v is not None
1167
+ )
1168
  return BaseModelOutputWithPast(
1169
  last_hidden_state=hidden_states,
1170
  past_key_values=next_cache,
 
1204
  return self.model
1205
 
1206
  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1207
+ @replace_return_docstrings(
1208
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1209
+ )
1210
  def forward(
1211
  self,
1212
  input_ids: torch.LongTensor = None,
 
1246
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1247
  ```"""
1248
 
1249
+ output_attentions = (
1250
+ output_attentions
1251
+ if output_attentions is not None
1252
+ else self.config.output_attentions
1253
+ )
1254
  output_hidden_states = (
1255
+ output_hidden_states
1256
+ if output_hidden_states is not None
1257
+ else self.config.output_hidden_states
1258
+ )
1259
+ return_dict = (
1260
+ return_dict if return_dict is not None else self.config.use_return_dict
1261
  )
 
1262
 
1263
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1264
  outputs = self.model(
 
1275
 
1276
  hidden_states = outputs[0]
1277
  if self.config.pretraining_tp > 1:
1278
+ lm_head_slices = self.lm_head.weight.split(
1279
+ self.vocab_size // self.config.pretraining_tp, dim=0
1280
+ )
1281
+ logits = [
1282
+ F.linear(hidden_states, lm_head_slices[i])
1283
+ for i in range(self.config.pretraining_tp)
1284
+ ]
1285
  logits = torch.cat(logits, dim=-1)
1286
  else:
1287
  logits = self.lm_head(hidden_states)
 
1313
  )
1314
 
1315
  def prepare_inputs_for_generation(
1316
+ self,
1317
+ input_ids,
1318
+ past_key_values=None,
1319
+ attention_mask=None,
1320
+ inputs_embeds=None,
1321
+ **kwargs,
1322
  ):
1323
  if past_key_values:
1324
  input_ids = input_ids[:, -1:]
 
1352
  reordered_past = ()
1353
  for layer_past in past_key_values:
1354
  reordered_past += (
1355
+ tuple(
1356
+ past_state.index_select(0, beam_idx.to(past_state.device))
1357
+ for past_state in layer_past
1358
+ ),
1359
  )
1360
  return reordered_past
1361
 
 
1411
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1412
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1413
  """
1414
+ return_dict = (
1415
+ return_dict if return_dict is not None else self.config.use_return_dict
1416
+ )
1417
 
1418
  transformer_outputs = self.model(
1419
  input_ids,
 
1435
  batch_size = inputs_embeds.shape[0]
1436
 
1437
  if self.config.pad_token_id is None and batch_size != 1:
1438
+ raise ValueError(
1439
+ "Cannot handle batch sizes > 1 if no padding token is defined."
1440
+ )
1441
  if self.config.pad_token_id is None:
1442
  sequence_lengths = -1
1443
  else:
1444
  if input_ids is not None:
1445
+ sequence_lengths = (
1446
+ torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1
1447
+ ).to(logits.device)
1448
  else:
1449
  sequence_lengths = -1
1450
 
1451
+ pooled_logits = logits[
1452
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1453
+ ]
1454
 
1455
  loss = None
1456
  if labels is not None:
 
1458
  if self.config.problem_type is None:
1459
  if self.num_labels == 1:
1460
  self.config.problem_type = "regression"
1461
+ elif self.num_labels > 1 and (
1462
+ labels.dtype == torch.long or labels.dtype == torch.int
1463
+ ):
1464
  self.config.problem_type = "single_label_classification"
1465
  else:
1466
  self.config.problem_type = "multi_label_classification"
 
1473
  loss = loss_fct(pooled_logits, labels)
1474
  elif self.config.problem_type == "single_label_classification":
1475
  loss_fct = CrossEntropyLoss()
1476
+ loss = loss_fct(
1477
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1478
+ )
1479
  elif self.config.problem_type == "multi_label_classification":
1480
  loss_fct = BCEWithLogitsLoss()
1481
  loss = loss_fct(pooled_logits, labels)