Matt commited on
Commit
5700dc4
1 Parent(s): cebeac0

Revert to Falcon naming

Browse files
Files changed (3) hide show
  1. config.json +6 -6
  2. configuration_RW.py +0 -147
  3. modeling_RW.py +0 -1262
config.json CHANGED
@@ -6,12 +6,12 @@
6
  ],
7
  "attention_dropout": 0.0,
8
  "auto_map": {
9
- "AutoConfig": "configuration_RW.RWConfig",
10
- "AutoModel": "modeling_RW.RWModel",
11
- "AutoModelForSequenceClassification": "modeling_RW.RWForSequenceClassification",
12
- "AutoModelForTokenClassification": "modeling_RW.RWForTokenClassification",
13
- "AutoModelForQuestionAnswering": "modeling_RW.RWForQuestionAnswering",
14
- "AutoModelForCausalLM": "modeling_RW.RWForCausalLM"
15
  },
16
  "bias": true,
17
  "bos_token_id": 1,
 
6
  ],
7
  "attention_dropout": 0.0,
8
  "auto_map": {
9
+ "AutoConfig": "configuration_falcon.FalconConfig",
10
+ "AutoModel": "modeling_falcon.FalconModel",
11
+ "AutoModelForSequenceClassification": "modeling_falcon.FalconForSequenceClassification",
12
+ "AutoModelForTokenClassification": "modeling_falcon.FalconForTokenClassification",
13
+ "AutoModelForQuestionAnswering": "modeling_falcon.FalconForQuestionAnswering",
14
+ "AutoModelForCausalLM": "modeling_falcon.FalconForCausalLM"
15
  },
16
  "bias": true,
17
  "bos_token_id": 1,
configuration_RW.py DELETED
@@ -1,147 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ Falcon configuration"""
16
- from transformers.configuration_utils import PretrainedConfig
17
- from transformers.utils import logging
18
-
19
-
20
- logger = logging.get_logger(__name__)
21
-
22
- FALCON_PRETRAINED_CONFIG_ARCHIVE_MAP = {
23
- "tiiuae/falcon-40b": "https://huggingface.co/tiiuae/falcon-40b/resolve/main/config.json",
24
- "tiiuae/falcon-7b": "https://huggingface.co/tiiuae/falcon-7b/resolve/main/config.json",
25
- }
26
-
27
-
28
- class RWConfig(PretrainedConfig):
29
- r"""
30
- This is the configuration class to store the configuration of a [`FalconModel`]. It is used to instantiate a Falcon
31
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
32
- defaults will yield a similar configuration to that of the
33
- [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b) architecture.
34
-
35
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
- documentation from [`PretrainedConfig`] for more information.
37
-
38
-
39
- Args:
40
- vocab_size (`int`, *optional*, defaults to 65024):
41
- Vocabulary size of the Falcon model. Defines the number of different tokens that can be represented by the
42
- `inputs_ids` passed when calling [`FalconModel`]
43
- hidden_size (`int`, *optional*, defaults to 4544):
44
- Dimension of the hidden representations.
45
- num_hidden_layers (`int`, *optional*, defaults to 32):
46
- Number of hidden layers in the Transformer decoder.
47
- num_attention_heads (`int`, *optional*, defaults to 71):
48
- Number of attention heads for each attention layer in the Transformer encoder.
49
- initializer_range (`float`, *optional*, defaults to 0.02):
50
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
51
- use_cache (`bool`, *optional*, defaults to `True`):
52
- Whether the model should return the last key/values attentions (not used by all models). Only relevant if
53
- `config.is_decoder=True`.
54
- layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
55
- The epsilon used by the layer normalization layers.
56
- hidden_dropout (`float`, *optional*, defaults to 0.0):
57
- The dropout probability for MLP layers.
58
- attention_dropout (`float`, *optional*, defaults to 0.0):
59
- The dropout probability for attention layers.
60
- num_kv_heads (`int`, *optional*):
61
- Number of key-value heads to use per attention layer. If unset, defaults to the same value as
62
- `num_attention_heads`.
63
- alibi (`bool`, *optional*, defaults to `False`):
64
- Whether to use ALiBi positional biases during self-attention.
65
- new_decoder_architecture (`bool`, *optional*, defaults to `False`):
66
- Whether to use the new (Falcon-40B) decoder architecture. If `True`, the `multi_query` and `parallel_attn`
67
- arguments are ignored, as the new decoder always uses parallel attention.
68
- multi_query (`bool`, *optional*, defaults to `True`):
69
- Whether to use multi-query attention in the decoder. Ignored when `new_decoder_architecture` is `True`.
70
- parallel_attn (`bool`, *optional*, defaults to `True`):
71
- Whether to compute attention in parallel with the feedforward layer. If False, they are consecutive
72
- instead, as in the original Transformer architecture. Ignored when `new_decoder_architecture` is `True`.
73
- bias (`bool`, *optional*, defaults to `False`):
74
- Whether to use bias on Linear layers.
75
- bos_token_id (`int`, *optional*, defaults to 11):
76
- The id of the "beginning-of-sequence" token.
77
- eos_token_id (`int`, *optional*, defaults to 11):
78
- The id of the "end-of-sequence" token.
79
-
80
- Example:
81
-
82
- ```python
83
- >>> from transformers import FalconModel, RWConfig
84
-
85
- >>> # Initializing a small (2-layer) Falcon configuration
86
- >>> configuration = RWConfig(num_hidden_layers=2)
87
-
88
- >>> # Initializing a model from the small configuration
89
- >>> model = FalconModel(configuration)
90
-
91
- >>> # Accessing the model configuration
92
- >>> configuration = model.config
93
- ```"""
94
- model_type = "falcon"
95
- keys_to_ignore_at_inference = ["past_key_values"]
96
-
97
- def __init__(
98
- self,
99
- vocab_size=65024,
100
- hidden_size=4544,
101
- num_hidden_layers=32,
102
- num_attention_heads=71,
103
- layer_norm_epsilon=1e-5,
104
- initializer_range=0.02,
105
- use_cache=True,
106
- hidden_dropout=0.0,
107
- attention_dropout=0.0,
108
- num_kv_heads=None,
109
- alibi=False,
110
- new_decoder_architecture=False,
111
- multi_query=True,
112
- parallel_attn=True,
113
- bias=False,
114
- bos_token_id=11,
115
- eos_token_id=11,
116
- **kwargs,
117
- ):
118
- self.vocab_size = vocab_size
119
- # Backward compatibility with n_embed kwarg
120
- n_embed = kwargs.pop("n_embed", None)
121
- self.hidden_size = hidden_size if n_embed is None else n_embed
122
- self.num_hidden_layers = num_hidden_layers
123
- self.num_attention_heads = num_attention_heads
124
- self.layer_norm_epsilon = layer_norm_epsilon
125
- self.initializer_range = initializer_range
126
- self.use_cache = use_cache
127
- self.hidden_dropout = hidden_dropout
128
- self.attention_dropout = attention_dropout
129
-
130
- self.bos_token_id = bos_token_id
131
- self.eos_token_id = eos_token_id
132
- self.num_kv_heads = num_attention_heads if num_kv_heads is None else num_kv_heads
133
- self.alibi = alibi
134
- self.new_decoder_architecture = new_decoder_architecture
135
- self.multi_query = multi_query # Ignored when new_decoder_architecture is True
136
- self.parallel_attn = parallel_attn
137
- self.bias = bias
138
-
139
- super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
140
-
141
- @property
142
- def head_dim(self):
143
- return self.hidden_size // self.num_attention_heads
144
-
145
- @property
146
- def rotary(self):
147
- return not self.alibi
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modeling_RW.py DELETED
@@ -1,1262 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """PyTorch Falcon model."""
16
-
17
- import math
18
- from typing import Optional, Tuple, Union
19
-
20
- import torch
21
- import torch.utils.checkpoint
22
- from torch import nn
23
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
24
- from torch.nn import functional as F
25
-
26
- from transformers.modeling_outputs import (
27
- BaseModelOutputWithPastAndCrossAttentions,
28
- CausalLMOutputWithCrossAttentions,
29
- QuestionAnsweringModelOutput,
30
- SequenceClassifierOutputWithPast,
31
- TokenClassifierOutput,
32
- )
33
- from transformers.modeling_utils import PreTrainedModel
34
- from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
35
- from .configuration_RW import RWConfig
36
-
37
-
38
- logger = logging.get_logger(__name__)
39
-
40
- FALCON_PRETRAINED_MODEL_ARCHIVE_LIST = [
41
- "tiiuae/falcon-40b",
42
- "tiiuae/falcon-40b-instruct",
43
- "tiiuae/falcon-7b",
44
- "tiiuae/falcon-7b-instruct",
45
- "tiiuae/falcon-rw-7b",
46
- "tiiuae/falcon-rw-1b",
47
- ]
48
- _CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
49
- _CONFIG_FOR_DOC = "RWConfig"
50
-
51
-
52
- # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
53
- # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
54
- class FalconLinear(nn.Linear):
55
- def forward(self, input: torch.Tensor) -> torch.Tensor:
56
- hidden_states = input @ self.weight.T
57
- if self.bias is None:
58
- return hidden_states
59
- return hidden_states + self.bias
60
-
61
-
62
- # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
63
- def rotate_half(x):
64
- x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
65
- return torch.cat((-x2, x1), dim=-1)
66
-
67
-
68
- class FalconRotaryEmbedding(nn.Module):
69
- """Implementation of RotaryEmbedding from GPT-NeoX.
70
- This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
71
- n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
72
- """
73
-
74
- def __init__(self, head_dim: int, base=10000):
75
- super().__init__()
76
- inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
77
- self.register_buffer("inv_freq", inv_freq, persistent=False)
78
- self.head_dim = head_dim
79
- self.seq_len_cached = -1
80
- self.cos_cached: torch.Tensor | None = None
81
- self.sin_cached: torch.Tensor | None = None
82
-
83
- def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
84
- total_length = seq_len + past_key_values_length
85
- if total_length > self.seq_len_cached:
86
- self.seq_len_cached = total_length
87
- t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
88
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
89
- emb = torch.cat((freqs, freqs), dim=-1).to(device)
90
-
91
- if dtype in [torch.float16, torch.bfloat16]:
92
- emb = emb.float()
93
-
94
- self.cos_cached = emb.cos()[None, :, :]
95
- self.sin_cached = emb.sin()[None, :, :]
96
-
97
- self.cos_cached = self.cos_cached.type(dtype)
98
- self.sin_cached = self.sin_cached.type(dtype)
99
-
100
- return (
101
- self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
102
- self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
103
- )
104
-
105
- def forward(self, query, key, past_key_values_length=0):
106
- batch, seq_len, head_dim = query.shape
107
- cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
108
- return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
109
-
110
-
111
- def _make_causal_mask(
112
- input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
113
- ) -> torch.BoolTensor:
114
- """
115
- Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it
116
- just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1,
117
- target_length, target_length+past_key_values_length]`.
118
- """
119
- batch_size, target_length = input_ids_shape
120
-
121
- mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.bool, device=device), diagonal=1)
122
- # If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op.
123
- # This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this
124
- # way avoids a data-dependent conditional, which will help me when I have to port this to XLA later.
125
- past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device)
126
- mask = torch.cat([past_mask, mask], dim=-1)
127
- expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
128
- return expanded_mask
129
-
130
-
131
- def _expand_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
132
- """
133
- Expands attention_mask from `[batch_size, seq_length]` to `[batch_size, 1, seq_length, seq_length + past_length]`.
134
- """
135
- batch_size, total_length = mask.shape
136
- seq_length = total_length - past_key_values_length if past_key_values_length is not None else total_length
137
-
138
- expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
139
- return expanded_mask.expand(batch_size, 1, seq_length, total_length)
140
-
141
-
142
- def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
143
- batch_size, seq_length = attention_mask.shape
144
- closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
145
- base = torch.tensor(
146
- 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
147
- )
148
- powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
149
- slopes = torch.pow(base, powers)
150
-
151
- if closest_power_of_2 != num_heads:
152
- extra_base = torch.tensor(
153
- 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
154
- )
155
- num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
156
- extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
157
- slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
158
-
159
- # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
160
- # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
161
- # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
162
- # => the query_length dimension will then be broadcasted correctly
163
- # This is more or less identical to T5's relative position bias:
164
- # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
165
- arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
166
- alibi = slopes[..., None].bfloat16() * arange_tensor
167
- return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
168
-
169
-
170
- # Copied from transformers.models.bloom.modeling_bloom.dropout_add
171
- def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
172
- """
173
- Dropout add function
174
-
175
- Args:
176
- x (`torch.tensor`, *required*):
177
- input tensor
178
- residual (`torch.tensor`, *required*):
179
- residual tensor
180
- prob (`float`, *required*):
181
- dropout probability
182
- training (`bool`, *required*):
183
- training mode
184
- """
185
- out = F.dropout(x, p=prob, training=training)
186
- out = residual + out
187
- return out
188
-
189
-
190
- class FalconAttention(nn.Module):
191
- def __init__(self, config: RWConfig):
192
- super().__init__()
193
-
194
- self.hidden_size = config.hidden_size
195
- self.num_heads = config.num_attention_heads
196
- self.head_dim = self.hidden_size // self.num_heads
197
- self.split_size = self.hidden_size
198
- self.hidden_dropout = config.hidden_dropout
199
-
200
- if self.head_dim * self.num_heads != self.hidden_size:
201
- raise ValueError(
202
- f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
203
- f" {self.num_heads})."
204
- )
205
-
206
- self.maybe_rotary = FalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
207
-
208
- # Layer-wise attention scaling
209
- self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
210
- self.beta = self.inv_norm_factor
211
- if config.new_decoder_architecture:
212
- qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
213
- elif config.multi_query:
214
- qkv_out_dim = self.hidden_size + 2 * self.head_dim
215
- else:
216
- qkv_out_dim = 3 * self.hidden_size
217
- self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
218
- self.new_decoder_architecture = config.new_decoder_architecture
219
- self.multi_query = config.multi_query
220
- self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
221
- self.attention_dropout = nn.Dropout(config.attention_dropout)
222
- self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
223
-
224
- def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
225
- """
226
- Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
227
-
228
- Args:
229
- fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
230
-
231
- Returns:
232
- query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
233
- value: [batch_size, seq_length, num_heads, head_dim]
234
- """
235
- if self.new_decoder_architecture:
236
- batch, seq_len, _ = fused_qkv.shape
237
- qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
238
- query = qkv[:, :, :, :-2]
239
- key = qkv[:, :, :, [-2]]
240
- value = qkv[:, :, :, [-1]]
241
- key = torch.broadcast_to(key, query.shape)
242
- value = torch.broadcast_to(value, query.shape)
243
-
244
- query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
245
- return query, key, value
246
- elif not self.multi_query:
247
- batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
248
- fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
249
- return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
250
- else:
251
- batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
252
- fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
253
- return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
254
-
255
- # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads
256
- def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
257
- """
258
- Merge heads together over the last dimenstion
259
-
260
- Args:
261
- x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
262
-
263
- Returns:
264
- torch.tensor: [batch_size, seq_length, num_heads * head_dim]
265
- """
266
- # What we want to achieve is:
267
- # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
268
- batch_size_and_num_heads, seq_length, _ = x.shape
269
- batch_size = batch_size_and_num_heads // self.num_heads
270
-
271
- # First view to decompose the batch size
272
- # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
273
- x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
274
-
275
- # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
276
- x = x.permute(0, 2, 1, 3)
277
-
278
- # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
279
- return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
280
-
281
- def forward(
282
- self,
283
- hidden_states: torch.Tensor,
284
- alibi: Optional[torch.Tensor],
285
- attention_mask: torch.Tensor,
286
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
287
- head_mask: Optional[torch.Tensor] = None,
288
- use_cache: bool = False,
289
- output_attentions: bool = False,
290
- ):
291
- fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
292
- num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
293
- # 3 x [batch_size, seq_length, num_heads, head_dim]
294
- (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
295
-
296
- batch_size, query_length, _, _ = query_layer.shape
297
-
298
- query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
299
- key_layer = key_layer.transpose(1, 2).reshape(
300
- batch_size * num_kv_heads,
301
- query_length,
302
- self.head_dim,
303
- )
304
- value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
305
-
306
- past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
307
- query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
308
-
309
- if layer_past is not None:
310
- past_key, past_value = layer_past
311
- # concatenate along seq_length dimension:
312
- # - key: [batch_size * self.num_heads, kv_length, head_dim]
313
- # - value: [batch_size * self.num_heads, kv_length, head_dim]
314
- key_layer = torch.cat((past_key, key_layer), dim=1)
315
- value_layer = torch.cat((past_value, value_layer), dim=1)
316
-
317
- _, kv_length, _ = key_layer.shape
318
- if use_cache:
319
- present = (key_layer, value_layer)
320
- else:
321
- present = None
322
-
323
- attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
324
-
325
- query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
326
- key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
327
- value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
328
-
329
- if alibi is None:
330
- if output_attentions:
331
- # F.scaled_dot_product_attention doesn't return the attention weights, so we have
332
- # to do it by hand if we want them
333
- attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
334
- attention_scores /= math.sqrt(self.head_dim)
335
-
336
- attention_scores = F.softmax(
337
- attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
338
- )
339
- attn_output = attention_scores @ value_layer_
340
- else:
341
- attn_output = F.scaled_dot_product_attention(
342
- query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
343
- )
344
- attention_scores = None
345
-
346
- attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
347
- attn_output = attn_output.permute(0, 2, 1, 3)
348
- attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
349
-
350
- output_tensor = self.dense(attn_output)
351
-
352
- if output_attentions:
353
- return output_tensor, present, attention_scores
354
- else:
355
- return output_tensor, present
356
-
357
- else:
358
- matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
359
-
360
- # change view to [batch_size, num_heads, q_length, kv_length]
361
- attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
362
-
363
- # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
364
- input_dtype = attention_scores.dtype
365
- # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
366
- if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
367
- attention_scores = attention_scores.to(torch.float32)
368
- # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
369
- # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
370
- # equivalent and more performant, but there might be a numerical difference. If you're reading this
371
- # and you'd like to experiment and maybe file a PR, feel free!
372
- attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
373
- attention_logits *= self.inv_norm_factor
374
- attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
375
- # [batch_size, num_heads, q_length, kv_length]
376
- attention_probs = self.attention_dropout(attention_probs)
377
-
378
- if head_mask is not None:
379
- attention_probs = attention_probs * head_mask
380
-
381
- # change view [batch_size, num_heads, q_length, kv_length]
382
- attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
383
-
384
- # matmul: [batch_size * num_heads, q_length, head_dim]
385
- context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
386
-
387
- # change view [batch_size, num_heads, q_length, head_dim]
388
- context_layer = self._merge_heads(context_layer)
389
-
390
- output_tensor = self.dense(context_layer)
391
-
392
- if output_attentions:
393
- return output_tensor, present, attention_probs
394
- else:
395
- return output_tensor, present
396
-
397
-
398
- class FalconMLP(nn.Module):
399
- def __init__(self, config: RWConfig):
400
- super().__init__()
401
- hidden_size = config.hidden_size
402
-
403
- self.dense_h_to_4h = FalconLinear(hidden_size, 4 * hidden_size, bias=config.bias)
404
- self.act = nn.GELU()
405
- self.dense_4h_to_h = FalconLinear(4 * hidden_size, hidden_size, bias=config.bias)
406
- self.hidden_dropout = config.hidden_dropout
407
-
408
- def forward(self, x: torch.Tensor) -> torch.Tensor:
409
- x = self.act(self.dense_h_to_4h(x))
410
- x = self.dense_4h_to_h(x)
411
- return x
412
-
413
-
414
- class FalconDecoderLayer(nn.Module):
415
- def __init__(self, config: RWConfig):
416
- super().__init__()
417
- hidden_size = config.hidden_size
418
- self.num_heads = config.num_attention_heads
419
- self.self_attention = FalconAttention(config)
420
- self.mlp = FalconMLP(config)
421
- self.hidden_dropout = config.hidden_dropout
422
- self.config = config
423
-
424
- if config.new_decoder_architecture:
425
- # The layer norm before self-attention
426
- self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
427
- # The layer norm before the MLP
428
- self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
429
- else:
430
- self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
431
- if not config.parallel_attn:
432
- self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
433
-
434
- def forward(
435
- self,
436
- hidden_states: torch.Tensor,
437
- alibi: Optional[torch.Tensor],
438
- attention_mask: torch.Tensor,
439
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
440
- head_mask: Optional[torch.Tensor] = None,
441
- use_cache: bool = False,
442
- output_attentions: bool = False,
443
- ):
444
- residual = hidden_states
445
-
446
- if self.config.new_decoder_architecture:
447
- attention_layernorm_out = self.ln_attn(hidden_states)
448
- mlp_layernorm_out = self.ln_mlp(hidden_states)
449
- else:
450
- attention_layernorm_out = self.input_layernorm(hidden_states)
451
-
452
- # Self attention.
453
- attn_outputs = self.self_attention(
454
- attention_layernorm_out,
455
- layer_past=layer_past,
456
- attention_mask=attention_mask,
457
- alibi=alibi,
458
- head_mask=head_mask,
459
- use_cache=use_cache,
460
- output_attentions=output_attentions,
461
- )
462
-
463
- attention_output = attn_outputs[0]
464
-
465
- if not self.config.new_decoder_architecture:
466
- if self.config.parallel_attn:
467
- mlp_layernorm_out = attention_layernorm_out
468
- else:
469
- residual = dropout_add(
470
- attention_output, residual, self.config.attention_dropout, training=self.training
471
- )
472
- mlp_layernorm_out = self.post_attention_layernorm(residual)
473
-
474
- outputs = attn_outputs[1:]
475
-
476
- # MLP.
477
- mlp_output = self.mlp(mlp_layernorm_out)
478
-
479
- if self.config.new_decoder_architecture or self.config.parallel_attn:
480
- mlp_output += attention_output
481
-
482
- output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
483
-
484
- if use_cache:
485
- outputs = (output,) + outputs
486
- else:
487
- outputs = (output,) + outputs[1:]
488
-
489
- return outputs # hidden_states, present, attentions
490
-
491
-
492
- FALCON_START_DOCSTRING = r"""
493
-
494
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
495
- library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
496
-
497
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
498
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
499
- and behavior.
500
-
501
- Parameters:
502
- config ([`RWConfig`]): Model configuration class with all the parameters of the model.
503
- Initializing with a config file does not load the weights associated with the model, only the
504
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
505
- """
506
-
507
- FALCON_INPUTS_DOCSTRING = r"""
508
- Args:
509
- input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
510
- `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
511
- (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
512
-
513
- If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
514
- `input_ids`.
515
-
516
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
517
- [`PreTrainedTokenizer.__call__`] for details.
518
-
519
- [What are input IDs?](../glossary#input-ids)
520
- past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_hidden_layers`):
521
- Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
522
- `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
523
- their past given to this model should not be passed as `input_ids` as they have already been computed.
524
-
525
- Each element of `past_key_values` is a tuple (past_key, past_value):
526
- - past_key: [batch_size * num_heads, head_dim, kv_length]
527
- - past_value: [batch_size * num_heads, kv_length, head_dim]
528
- attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
529
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
530
-
531
- - 1 for tokens that are **not masked**,
532
- - 0 for tokens that are **masked**.
533
-
534
- [What are attention masks?](../glossary#attention-mask)
535
- head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
536
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
537
-
538
- - 1 indicates the head is **not masked**,
539
- - 0 indicates the head is **masked**.
540
-
541
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
542
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
543
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
544
- model's internal embedding lookup matrix.
545
-
546
- If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
547
- `past_key_values`).
548
- use_cache (`bool`, *optional*):
549
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
550
- `past_key_values`).
551
- output_attentions (`bool`, *optional*):
552
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
553
- tensors for more detail.
554
- output_hidden_states (`bool`, *optional*):
555
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
556
- more detail.
557
- return_dict (`bool`, *optional*):
558
- Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
559
- """
560
-
561
-
562
- class RWPreTrainedModel(PreTrainedModel):
563
- """
564
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
565
- models.
566
- """
567
-
568
- config_class = RWConfig
569
- base_model_prefix = "transformer"
570
- supports_gradient_checkpointing = True
571
- _no_split_modules = ["FalconDecoderLayer"]
572
-
573
- def __init__(self, *inputs, **kwargs):
574
- super().__init__(*inputs, **kwargs)
575
-
576
- def _init_weights(self, module: nn.Module):
577
- """Initialize the weights."""
578
- if isinstance(module, nn.Linear) or isinstance(module, FalconLinear):
579
- # Slightly different from the TF version which uses truncated_normal for initialization
580
- # cf https://github.com/pytorch/pytorch/pull/5617
581
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
582
- if module.bias is not None:
583
- module.bias.data.zero_()
584
- elif isinstance(module, nn.Embedding):
585
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
586
- if module.padding_idx is not None:
587
- module.weight.data[module.padding_idx].zero_()
588
- elif isinstance(module, LayerNorm):
589
- module.bias.data.zero_()
590
- module.weight.data.fill_(1.0)
591
-
592
- # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->RWModel
593
- def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
594
- if isinstance(module, RWModel):
595
- module.gradient_checkpointing = value
596
-
597
- @staticmethod
598
- def _convert_cache_to_standard_format(
599
- past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
600
- ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
601
- """
602
- Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
603
- num_heads, ...]))
604
- """
605
- batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
606
- # [batch_size * self.num_heads, kv_length, head_dim] -> [batch_size, num_heads, kv_length, head_dim]
607
- # Note that don't want to use self.num_attention_heads because the number of heads may vary depending
608
- # on whether we use multi_query attention.
609
- num_heads = batch_size_times_num_heads // batch_size
610
- return tuple(
611
- (
612
- layer_past[0].view(batch_size, num_heads, kv_length, head_dim),
613
- layer_past[1].view(batch_size, num_heads, kv_length, head_dim),
614
- )
615
- for layer_past in past_key_value
616
- )
617
-
618
- @staticmethod
619
- def _convert_to_rw_cache(
620
- past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
621
- ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
622
- batch_size, num_heads, kv_length, head_dim = past_key_value[0][0].shape
623
- batch_size_times_num_heads = batch_size * num_heads
624
- # [batch_size, num_heads, kv_length, head_dim] -> [batch_size * num_heads, kv_length, head_dim]
625
- return tuple(
626
- (
627
- layer_past[0].view(batch_size_times_num_heads, kv_length, head_dim),
628
- layer_past[1].view(batch_size_times_num_heads, kv_length, head_dim),
629
- )
630
- for layer_past in past_key_value
631
- )
632
-
633
-
634
- @add_start_docstrings(
635
- "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
636
- FALCON_START_DOCSTRING,
637
- )
638
- class RWModel(RWPreTrainedModel):
639
- def __init__(self, config: RWConfig):
640
- super().__init__(config)
641
-
642
- self.embed_dim = config.hidden_size
643
- self.num_heads = config.num_attention_heads
644
- self.use_alibi = config.alibi
645
-
646
- # Embedding + LN Embedding
647
- self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
648
-
649
- # Transformer blocks
650
- self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
651
-
652
- # Final Layer Norm
653
- self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
654
-
655
- self.gradient_checkpointing = False
656
-
657
- # Initialize weights and apply final processing
658
- self.post_init()
659
-
660
- def get_input_embeddings(self):
661
- return self.word_embeddings
662
-
663
- @staticmethod
664
- def _prepare_attn_mask(
665
- attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
666
- ) -> torch.BoolTensor:
667
- # Create a causal mask
668
- # The attention mask we receive as input should cover the whole extended sequence, including any past
669
- # cache, so its shape should be [batch_size, seq_length + past_key_values_length]
670
- # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
671
- if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
672
- raise ValueError(
673
- "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
674
- f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
675
- f" {past_key_values_length}."
676
- )
677
- combined_attention_mask = None
678
- device = attention_mask.device
679
- _, seq_length = input_shape
680
-
681
- if seq_length > 1:
682
- combined_attention_mask = _make_causal_mask(
683
- input_shape, device=device, past_key_values_length=past_key_values_length
684
- )
685
-
686
- # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
687
- expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length)
688
- combined_attention_mask = (
689
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
690
- )
691
-
692
- return combined_attention_mask
693
-
694
- def set_input_embeddings(self, new_embeddings: torch.Tensor):
695
- self.word_embeddings = new_embeddings
696
-
697
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
698
- @add_code_sample_docstrings(
699
- checkpoint=_CHECKPOINT_FOR_DOC,
700
- output_type=BaseModelOutputWithPastAndCrossAttentions,
701
- config_class=_CONFIG_FOR_DOC,
702
- )
703
- def forward(
704
- self,
705
- input_ids: Optional[torch.LongTensor] = None,
706
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
707
- attention_mask: Optional[torch.Tensor] = None,
708
- head_mask: Optional[torch.LongTensor] = None,
709
- inputs_embeds: Optional[torch.LongTensor] = None,
710
- use_cache: Optional[bool] = None,
711
- output_attentions: Optional[bool] = None,
712
- output_hidden_states: Optional[bool] = None,
713
- return_dict: Optional[bool] = None,
714
- ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
715
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
716
- output_hidden_states = (
717
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
718
- )
719
- use_cache = use_cache if use_cache is not None else self.config.use_cache
720
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
721
-
722
- if input_ids is not None and inputs_embeds is not None:
723
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
724
- elif input_ids is not None:
725
- batch_size, seq_length = input_ids.shape
726
- elif inputs_embeds is not None:
727
- batch_size, seq_length, _ = inputs_embeds.shape
728
- else:
729
- raise ValueError("You have to specify either input_ids or inputs_embeds")
730
-
731
- if past_key_values is None:
732
- past_key_values = tuple([None] * len(self.h))
733
- else:
734
- past_key_values = self._convert_to_rw_cache(past_key_values)
735
-
736
- # Prepare head mask if needed
737
- # 1.0 in head_mask indicate we keep the head
738
- # attention_probs has shape batch_size x num_heads x N x N
739
- # head_mask has shape n_layer x batch x num_heads x N x N
740
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
741
-
742
- if inputs_embeds is None:
743
- inputs_embeds = self.word_embeddings(input_ids)
744
-
745
- hidden_states = inputs_embeds
746
-
747
- presents = () if use_cache else None
748
- all_self_attentions = () if output_attentions else None
749
- all_hidden_states = () if output_hidden_states else None
750
-
751
- # Compute alibi tensor: check build_alibi_tensor documentation
752
- past_key_values_length = 0
753
- if past_key_values[0] is not None:
754
- past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
755
- if attention_mask is None:
756
- attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
757
- else:
758
- attention_mask = attention_mask.to(hidden_states.device)
759
-
760
- if self.use_alibi:
761
- alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
762
- else:
763
- alibi = None
764
-
765
- causal_mask = self._prepare_attn_mask(
766
- attention_mask,
767
- input_shape=(batch_size, seq_length),
768
- past_key_values_length=past_key_values_length,
769
- )
770
-
771
- for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
772
- if output_hidden_states:
773
- all_hidden_states = all_hidden_states + (hidden_states,)
774
-
775
- if self.gradient_checkpointing and self.training:
776
- if use_cache:
777
- logger.warning(
778
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
779
- )
780
- use_cache = False
781
-
782
- def create_custom_forward(module):
783
- def custom_forward(*inputs):
784
- # None for past_key_value
785
- return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
786
-
787
- return custom_forward
788
-
789
- outputs = torch.utils.checkpoint.checkpoint(
790
- create_custom_forward(block),
791
- hidden_states,
792
- alibi,
793
- causal_mask,
794
- head_mask[i],
795
- )
796
- else:
797
- outputs = block(
798
- hidden_states,
799
- layer_past=layer_past,
800
- attention_mask=causal_mask,
801
- head_mask=head_mask[i],
802
- use_cache=use_cache,
803
- output_attentions=output_attentions,
804
- alibi=alibi,
805
- )
806
-
807
- hidden_states = outputs[0]
808
- if use_cache is True:
809
- presents = presents + (outputs[1],)
810
-
811
- if output_attentions:
812
- all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
813
-
814
- # Add last hidden state
815
- hidden_states = self.ln_f(hidden_states)
816
-
817
- if output_hidden_states:
818
- all_hidden_states = all_hidden_states + (hidden_states,)
819
-
820
- if presents is not None:
821
- presents = self._convert_cache_to_standard_format(presents, batch_size)
822
-
823
- if not return_dict:
824
- return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
825
-
826
- return BaseModelOutputWithPastAndCrossAttentions(
827
- last_hidden_state=hidden_states,
828
- past_key_values=presents,
829
- hidden_states=all_hidden_states,
830
- attentions=all_self_attentions,
831
- )
832
-
833
-
834
- @add_start_docstrings(
835
- "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
836
- FALCON_START_DOCSTRING,
837
- )
838
- class RWForCausalLM(RWPreTrainedModel):
839
- _tied_weights_keys = ["lm_head.weight"]
840
-
841
- def __init__(self, config: RWConfig):
842
- super().__init__(config)
843
- self.transformer = RWModel(config)
844
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
845
-
846
- # Initialize weights and apply final processing
847
- self.post_init()
848
-
849
- def get_output_embeddings(self):
850
- return self.lm_head
851
-
852
- def set_output_embeddings(self, new_embeddings: torch.Tensor):
853
- self.lm_head = new_embeddings
854
-
855
- def prepare_inputs_for_generation(
856
- self,
857
- input_ids: torch.LongTensor,
858
- past_key_values: Optional[torch.Tensor] = None,
859
- attention_mask: Optional[torch.Tensor] = None,
860
- **kwargs,
861
- ) -> dict:
862
- if past_key_values is not None:
863
- input_ids = input_ids[:, -1:]
864
-
865
- return {
866
- "input_ids": input_ids,
867
- "past_key_values": past_key_values,
868
- "use_cache": kwargs.get("use_cache"),
869
- "attention_mask": attention_mask,
870
- }
871
-
872
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
873
- @add_code_sample_docstrings(
874
- checkpoint=_CHECKPOINT_FOR_DOC,
875
- output_type=CausalLMOutputWithCrossAttentions,
876
- config_class=_CONFIG_FOR_DOC,
877
- )
878
- def forward(
879
- self,
880
- input_ids: Optional[torch.LongTensor] = None,
881
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
882
- attention_mask: Optional[torch.Tensor] = None,
883
- head_mask: Optional[torch.Tensor] = None,
884
- inputs_embeds: Optional[torch.Tensor] = None,
885
- labels: Optional[torch.Tensor] = None,
886
- use_cache: Optional[bool] = None,
887
- output_attentions: Optional[bool] = None,
888
- output_hidden_states: Optional[bool] = None,
889
- return_dict: Optional[bool] = None,
890
- ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
891
- r"""
892
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
893
- Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
894
- `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
895
- are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
896
- """
897
-
898
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
899
-
900
- transformer_outputs = self.transformer(
901
- input_ids,
902
- past_key_values=past_key_values,
903
- attention_mask=attention_mask,
904
- head_mask=head_mask,
905
- inputs_embeds=inputs_embeds,
906
- use_cache=use_cache,
907
- output_attentions=output_attentions,
908
- output_hidden_states=output_hidden_states,
909
- return_dict=return_dict,
910
- )
911
- hidden_states = transformer_outputs[0]
912
-
913
- lm_logits = self.lm_head(hidden_states)
914
-
915
- loss = None
916
- if labels is not None:
917
- # Shift so that tokens < n predict n
918
- shift_logits = lm_logits[..., :-1, :].contiguous()
919
- shift_labels = labels[..., 1:].contiguous()
920
- batch_size, seq_length, vocab_size = shift_logits.shape
921
- # Flatten the tokens
922
- loss_fct = CrossEntropyLoss()
923
- loss = loss_fct(
924
- shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
925
- )
926
-
927
- if not return_dict:
928
- output = (lm_logits,) + transformer_outputs[1:]
929
- return ((loss,) + output) if loss is not None else output
930
-
931
- return CausalLMOutputWithCrossAttentions(
932
- loss=loss,
933
- logits=lm_logits,
934
- past_key_values=transformer_outputs.past_key_values,
935
- hidden_states=transformer_outputs.hidden_states,
936
- attentions=transformer_outputs.attentions,
937
- )
938
-
939
- def _reorder_cache(
940
- self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
941
- ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
942
- """
943
- This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
944
- [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
945
- beam_idx at every generation step.
946
-
947
- Output shares the same memory storage as `past`.
948
- """
949
-
950
- # Get a copy of `beam_idx` on all the devices where we need those indices.
951
- device_to_beam_idx = {
952
- past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
953
- }
954
- reordered_past = tuple(
955
- (
956
- layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
957
- layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
958
- )
959
- for layer_past in past
960
- )
961
- return reordered_past
962
-
963
-
964
- @add_start_docstrings(
965
- """
966
- The Falcon Model transformer with a sequence classification head on top (linear layer).
967
-
968
- [`RWForSequenceClassification`] uses the last token in order to do the classification, as other causal models
969
- (e.g. GPT-1) do.
970
-
971
- Since it does classification on the last token, it requires to know the position of the last token. If a
972
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
973
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
974
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
975
- each row of the batch).
976
- """,
977
- FALCON_START_DOCSTRING,
978
- )
979
- class RWForSequenceClassification(RWPreTrainedModel):
980
- def __init__(self, config: RWConfig):
981
- super().__init__(config)
982
- self.num_labels = config.num_labels
983
- self.transformer = RWModel(config)
984
- self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
985
-
986
- # Initialize weights and apply final processing
987
- self.post_init()
988
-
989
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
990
- @add_code_sample_docstrings(
991
- checkpoint=_CHECKPOINT_FOR_DOC,
992
- output_type=SequenceClassifierOutputWithPast,
993
- config_class=_CONFIG_FOR_DOC,
994
- )
995
- def forward(
996
- self,
997
- input_ids: Optional[torch.LongTensor] = None,
998
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
999
- attention_mask: Optional[torch.Tensor] = None,
1000
- head_mask: Optional[torch.Tensor] = None,
1001
- inputs_embeds: Optional[torch.Tensor] = None,
1002
- labels: Optional[torch.Tensor] = None,
1003
- use_cache: Optional[bool] = None,
1004
- output_attentions: Optional[bool] = None,
1005
- output_hidden_states: Optional[bool] = None,
1006
- return_dict: Optional[bool] = None,
1007
- ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
1008
- r"""
1009
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1010
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1011
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1012
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1013
- """
1014
-
1015
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1016
-
1017
- transformer_outputs = self.transformer(
1018
- input_ids,
1019
- past_key_values=past_key_values,
1020
- attention_mask=attention_mask,
1021
- head_mask=head_mask,
1022
- inputs_embeds=inputs_embeds,
1023
- use_cache=use_cache,
1024
- output_attentions=output_attentions,
1025
- output_hidden_states=output_hidden_states,
1026
- return_dict=return_dict,
1027
- )
1028
-
1029
- hidden_states = transformer_outputs[0]
1030
- logits = self.score(hidden_states)
1031
-
1032
- if input_ids is not None:
1033
- batch_size = input_ids.shape[0]
1034
- else:
1035
- batch_size = inputs_embeds.shape[0]
1036
-
1037
- if self.config.pad_token_id is None and batch_size != 1:
1038
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1039
- if self.config.pad_token_id is None:
1040
- sequence_lengths = -1
1041
- else:
1042
- if input_ids is not None:
1043
- sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
1044
- else:
1045
- sequence_lengths = -1
1046
- logger.warning(
1047
- f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1048
- "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1049
- )
1050
-
1051
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1052
-
1053
- loss = None
1054
- if labels is not None:
1055
- if self.config.problem_type is None:
1056
- if self.num_labels == 1:
1057
- self.config.problem_type = "regression"
1058
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1059
- self.config.problem_type = "single_label_classification"
1060
- else:
1061
- self.config.problem_type = "multi_label_classification"
1062
-
1063
- if self.config.problem_type == "regression":
1064
- loss_fct = MSELoss()
1065
- if self.num_labels == 1:
1066
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1067
- else:
1068
- loss = loss_fct(pooled_logits, labels)
1069
- elif self.config.problem_type == "single_label_classification":
1070
- loss_fct = CrossEntropyLoss()
1071
- loss = loss_fct(pooled_logits, labels)
1072
- elif self.config.problem_type == "multi_label_classification":
1073
- loss_fct = BCEWithLogitsLoss()
1074
- loss = loss_fct(pooled_logits, labels)
1075
- if not return_dict:
1076
- output = (pooled_logits,) + transformer_outputs[1:]
1077
- return ((loss,) + output) if loss is not None else output
1078
-
1079
- return SequenceClassifierOutputWithPast(
1080
- loss=loss,
1081
- logits=pooled_logits,
1082
- past_key_values=transformer_outputs.past_key_values,
1083
- hidden_states=transformer_outputs.hidden_states,
1084
- attentions=transformer_outputs.attentions,
1085
- )
1086
-
1087
-
1088
- @add_start_docstrings(
1089
- """
1090
- Falcon Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1091
- Named-Entity-Recognition (NER) tasks.
1092
- """,
1093
- FALCON_START_DOCSTRING,
1094
- )
1095
- class RWForTokenClassification(RWPreTrainedModel):
1096
- def __init__(self, config: RWConfig):
1097
- super().__init__(config)
1098
- self.num_labels = config.num_labels
1099
-
1100
- self.transformer = RWModel(config)
1101
- if getattr(config, "classifier_dropout", None) is not None:
1102
- classifier_dropout = config.classifier_dropout
1103
- elif getattr(config, "hidden_dropout", None) is not None:
1104
- classifier_dropout = config.hidden_dropout
1105
- else:
1106
- classifier_dropout = 0.1
1107
- self.dropout = nn.Dropout(classifier_dropout)
1108
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1109
-
1110
- # Initialize weights and apply final processing
1111
- self.post_init()
1112
-
1113
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1114
- @add_code_sample_docstrings(
1115
- checkpoint=_CHECKPOINT_FOR_DOC,
1116
- output_type=TokenClassifierOutput,
1117
- config_class=_CONFIG_FOR_DOC,
1118
- )
1119
- def forward(
1120
- self,
1121
- input_ids: Optional[torch.LongTensor] = None,
1122
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1123
- attention_mask: Optional[torch.Tensor] = None,
1124
- head_mask: Optional[torch.Tensor] = None,
1125
- inputs_embeds: Optional[torch.Tensor] = None,
1126
- labels: Optional[torch.Tensor] = None,
1127
- use_cache: Optional[bool] = None,
1128
- output_attentions: Optional[bool] = None,
1129
- output_hidden_states: Optional[bool] = None,
1130
- return_dict: Optional[bool] = None,
1131
- ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1132
- r"""
1133
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1134
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1135
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1136
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1137
- """
1138
-
1139
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1140
-
1141
- transformer_outputs = self.transformer(
1142
- input_ids,
1143
- past_key_values=past_key_values,
1144
- attention_mask=attention_mask,
1145
- head_mask=head_mask,
1146
- inputs_embeds=inputs_embeds,
1147
- use_cache=use_cache,
1148
- output_attentions=output_attentions,
1149
- output_hidden_states=output_hidden_states,
1150
- return_dict=return_dict,
1151
- )
1152
-
1153
- hidden_states = transformer_outputs[0]
1154
- hidden_states = self.dropout(hidden_states)
1155
- logits = self.classifier(hidden_states)
1156
-
1157
- loss = None
1158
- if labels is not None:
1159
- batch_size, seq_length = labels.shape
1160
- loss_fct = CrossEntropyLoss()
1161
- loss = loss_fct(
1162
- logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1163
- )
1164
-
1165
- if not return_dict:
1166
- output = (logits,) + transformer_outputs[2:]
1167
- return ((loss,) + output) if loss is not None else output
1168
-
1169
- return TokenClassifierOutput(
1170
- loss=loss,
1171
- logits=logits,
1172
- hidden_states=transformer_outputs.hidden_states,
1173
- attentions=transformer_outputs.attentions,
1174
- )
1175
-
1176
-
1177
- @add_start_docstrings(
1178
- """
1179
- The Falcon Model transformer with a span classification head on top for extractive question-answering tasks like
1180
- SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1181
- """,
1182
- FALCON_START_DOCSTRING,
1183
- )
1184
- class RWForQuestionAnswering(RWPreTrainedModel):
1185
- def __init__(self, config):
1186
- super().__init__(config)
1187
- self.transformer = RWModel(config)
1188
- self.qa_outputs = nn.Linear(config.hidden_size, 2)
1189
-
1190
- # Initialize weights and apply final processing
1191
- self.post_init()
1192
-
1193
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1194
- def forward(
1195
- self,
1196
- input_ids: Optional[torch.LongTensor] = None,
1197
- attention_mask: Optional[torch.FloatTensor] = None,
1198
- head_mask: Optional[torch.FloatTensor] = None,
1199
- inputs_embeds: Optional[torch.FloatTensor] = None,
1200
- start_positions: Optional[torch.LongTensor] = None,
1201
- end_positions: Optional[torch.LongTensor] = None,
1202
- output_attentions: Optional[bool] = None,
1203
- output_hidden_states: Optional[bool] = None,
1204
- return_dict: Optional[bool] = None,
1205
- ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1206
- r"""
1207
- start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1208
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
1209
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1210
- are not taken into account for computing the loss.
1211
- end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1212
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
1213
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1214
- are not taken into account for computing the loss.
1215
- """
1216
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1217
-
1218
- outputs = self.transformer(
1219
- input_ids,
1220
- attention_mask=attention_mask,
1221
- head_mask=head_mask,
1222
- inputs_embeds=inputs_embeds,
1223
- output_attentions=output_attentions,
1224
- output_hidden_states=output_hidden_states,
1225
- return_dict=return_dict,
1226
- )
1227
-
1228
- sequence_output = outputs[0]
1229
-
1230
- logits = self.qa_outputs(sequence_output)
1231
- start_logits, end_logits = logits.split(1, dim=-1)
1232
- start_logits = start_logits.squeeze(-1).contiguous()
1233
- end_logits = end_logits.squeeze(-1).contiguous()
1234
-
1235
- total_loss = None
1236
- if start_positions is not None and end_positions is not None:
1237
- # If we are on multi-GPU, split add a dimension
1238
- if len(start_positions.size()) > 1:
1239
- start_positions = start_positions.squeeze(-1)
1240
- if len(end_positions.size()) > 1:
1241
- end_positions = end_positions.squeeze(-1)
1242
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
1243
- ignored_index = start_logits.size(1)
1244
- start_positions = start_positions.clamp(0, ignored_index)
1245
- end_positions = end_positions.clamp(0, ignored_index)
1246
-
1247
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1248
- start_loss = loss_fct(start_logits, start_positions)
1249
- end_loss = loss_fct(end_logits, end_positions)
1250
- total_loss = (start_loss + end_loss) / 2
1251
-
1252
- if not return_dict:
1253
- output = (start_logits, end_logits) + outputs[2:]
1254
- return ((total_loss,) + output) if total_loss is not None else output
1255
-
1256
- return QuestionAnsweringModelOutput(
1257
- loss=total_loss,
1258
- start_logits=start_logits,
1259
- end_logits=end_logits,
1260
- hidden_states=outputs.hidden_states,
1261
- attentions=outputs.attentions,
1262
- )