BioMike committed
Commit b605b7a
1 Parent(s): 91fdc93

Delete modeling

modeling/__init__.py DELETED
@@ -1,2 +0,0 @@
- from .model import MT5ForConditionalGeneration
- from .config import MT5Config
 
 
 
modeling/config.py DELETED
@@ -1,133 +0,0 @@
- from transformers import PretrainedConfig
-
- class MT5Config(PretrainedConfig):
-     r"""
-     This is the configuration class to store the configuration of a [`MT5Model`] or a [`TFMT5Model`]. It is used to
-     instantiate a mT5 model according to the specified arguments, defining the model architecture. Instantiating a
-     configuration with the defaults will yield a similar configuration to that of the mT5
-     [google/mt5-small](https://huggingface.co/google/mt5-small) architecture.
-
-     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
-     documentation from [`PretrainedConfig`] for more information.
-
-     Arguments:
-         encoder_vocab_size (`int`, *optional*, defaults to 250112):
-             Vocabulary size of the encoder. Defines the number of different tokens that can be represented by the
-             `inputs_ids` passed when calling [`MT5Model`] or [`MT5ForConditionalGeneration`].
-         decoder_vocab_size (`int`, *optional*, defaults to 250112):
-             Vocabulary size of the decoder and of the language modeling head.
-         shared_embedding (`bool`, *optional*, defaults to `False`):
-             Whether the encoder and decoder share a single embedding matrix.
-         d_model (`int`, *optional*, defaults to 256):
-             Size of the encoder layers and the pooler layer.
-         d_kv (`int`, *optional*, defaults to 64):
-             Size of the key, query, value projections per attention head. Conventionally `d_kv` is expected to equal
-             `d_model // num_heads`, but in the mt5-small architecture it does not. The `inner_dim` of the projection
-             layer is defined as `num_heads * d_kv`.
-         d_ff (`int`, *optional*, defaults to 512):
-             Size of the intermediate feed forward layer in each `T5Block`.
-         num_layers (`int`, *optional*, defaults to 4):
-             Number of hidden layers in the Transformer encoder.
-         num_decoder_layers (`int`, *optional*):
-             Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
-         num_heads (`int`, *optional*, defaults to 3):
-             Number of attention heads for each attention layer in the Transformer encoder.
-         relative_attention_num_buckets (`int`, *optional*, defaults to 32):
-             The number of buckets to use for each attention layer.
-         relative_attention_max_distance (`int`, *optional*, defaults to 128):
-             The maximum distance of the longer sequences for the bucket separation.
-         dropout_rate (`float`, *optional*, defaults to 0.1):
-             The ratio for all dropout layers.
-         classifier_dropout (`float`, *optional*, defaults to 0.0):
-             The dropout ratio for the classifier.
-         layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
-             The epsilon used by the layer normalization layers.
-         initializer_factor (`float`, *optional*, defaults to 1):
-             A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
-             testing).
-         feed_forward_proj (`string`, *optional*, defaults to `"gated-gelu"`):
-             Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`.
-         use_cache (`bool`, *optional*, defaults to `True`):
-             Whether or not the model should return the last key/values attentions (not used by all models).
-     """
-
-     model_type = "mt5"
-     keys_to_ignore_at_inference = ["past_key_values"]
-
-     def __init__(
-         self,
-         encoder_vocab_size=250112,
-         decoder_vocab_size=250112,
-         shared_embedding=False,
-         d_model=256,
-         d_kv=64,
-         d_ff=512,
-         num_layers=4,
-         num_decoder_layers=None,
-         num_heads=3,
-         relative_attention_num_buckets=32,
-         relative_attention_max_distance=128,
-         dropout_rate=0.1,
-         layer_norm_epsilon=1e-6,
-         initializer_factor=1.0,
-         feed_forward_proj="gated-gelu",
-         is_encoder_decoder=True,
-         use_cache=True,
-         tokenizer_class="ChemTokenizers.SMILES_IUPAC_FAST.FastTokenizer",
-         tie_word_embeddings=False,
-         pad_token_id=0,
-         eos_token_id=1,
-         decoder_start_token_id=2,
-         classifier_dropout=0.0,
-         **kwargs,
-     ):
-         super().__init__(
-             is_encoder_decoder=is_encoder_decoder,
-             tokenizer_class=tokenizer_class,
-             tie_word_embeddings=tie_word_embeddings,
-             pad_token_id=pad_token_id,
-             eos_token_id=eos_token_id,
-             decoder_start_token_id=decoder_start_token_id,
-             **kwargs,
-         )
-         self.encoder_vocab_size = encoder_vocab_size
-         self.decoder_vocab_size = decoder_vocab_size
-         self.shared_embedding = shared_embedding
-         self.d_model = d_model
-         self.d_kv = d_kv
-         self.d_ff = d_ff
-         self.num_layers = num_layers
-         self.num_decoder_layers = (
-             num_decoder_layers if num_decoder_layers is not None else self.num_layers
-         )  # default = symmetry
-         self.num_heads = num_heads
-         self.relative_attention_num_buckets = relative_attention_num_buckets
-         self.relative_attention_max_distance = relative_attention_max_distance
-         self.dropout_rate = dropout_rate
-         self.classifier_dropout = classifier_dropout
-         self.layer_norm_epsilon = layer_norm_epsilon
-         self.initializer_factor = initializer_factor
-         self.feed_forward_proj = feed_forward_proj
-         self.use_cache = use_cache
-
-         act_info = self.feed_forward_proj.split("-")
-         self.dense_act_fn = act_info[-1]
-         self.is_gated_act = act_info[0] == "gated"
-
-         if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
-             raise ValueError(
-                 f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
-                 "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
-                 "'gated-gelu' or 'relu'"
-             )
-
-         # for backwards compatibility
-         if feed_forward_proj == "gated-gelu":
-             self.dense_act_fn = "gelu_new"
-
-     @property
-     def hidden_size(self):
-         return self.d_model
-
-     @property
-     def num_attention_heads(self):
-         return self.num_heads
-
-     @property
-     def num_hidden_layers(self):
-         return self.num_layers
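
The deleted `MT5Config` replaces the single `vocab_size` of the stock mT5 config with separate `encoder_vocab_size` / `decoder_vocab_size` and a `shared_embedding` switch. As an illustration only (not part of the commit), here is a minimal sketch of how the class above could be instantiated, assuming the `modeling` package is still present and `transformers` is installed; the vocabulary sizes are made-up values:

```python
# Hypothetical usage sketch of the deleted MT5Config (assumes the package layout above).
from modeling.config import MT5Config

config = MT5Config(
    encoder_vocab_size=1000,  # illustrative size for the source vocabulary
    decoder_vocab_size=600,   # illustrative size for the target vocabulary
    shared_embedding=False,   # keep separate embedding matrices for the two vocabularies
    d_model=256,
    d_kv=64,
    d_ff=512,
    num_layers=4,
    num_heads=3,
)

# `hidden_size`, `num_attention_heads` and `num_hidden_layers` are the property aliases defined above.
assert config.hidden_size == 256
assert config.num_attention_heads == 3
assert config.num_decoder_layers == 4  # falls back to `num_layers` when not given
```

The asymmetric vocabularies presumably match the chemistry tokenizer referenced by the default `tokenizer_class`, where the two token sets differ in size.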
 
modeling/docstrings.py DELETED
@@ -1,217 +0,0 @@
- PARALLELIZE_DOCSTRING = r"""
-     This is an experimental feature and is subject to change at a moment's notice.
-
-     Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
-     it will evenly distribute blocks across all devices.
-
-     Args:
-         device_map (`Dict[int, list]`, optional, defaults to None):
-             A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
-             automatically mapped to the first device (for esoteric reasons). That means that the first device should
-             have fewer attention modules mapped to it than other devices. For reference, the mt5 models have the
-             following number of attention modules:
-
-                 - mt5-small: 6
-                 - mt5-base: 12
-                 - mt5-large: 24
-                 - mt5-xl: 24
-                 - mt5-xxl: 24
-
-     Example:
-
-     ```python
-     # Here is an example of a device map on a machine with 4 GPUs using mt5-xl, which has a total of 24 attention modules:
-     model = MT5ForConditionalGeneration.from_pretrained("mt5-xl")
-     device_map = {
-         0: [0, 1, 2],
-         1: [3, 4, 5, 6, 7, 8, 9],
-         2: [10, 11, 12, 13, 14, 15, 16],
-         3: [17, 18, 19, 20, 21, 22, 23],
-     }
-     model.parallelize(device_map)
-     ```
- """
- DEPARALLELIZE_DOCSTRING = r"""
-     Moves the model to cpu from a model parallel state.
-
-     Example:
-
-     ```python
-     # On a 4 GPU machine with mt5-xl:
-     model = MT5ForConditionalGeneration.from_pretrained("mt5-xl")
-     device_map = {
-         0: [0, 1, 2],
-         1: [3, 4, 5, 6, 7, 8, 9],
-         2: [10, 11, 12, 13, 14, 15, 16],
-         3: [17, 18, 19, 20, 21, 22, 23],
-     }
-     model.parallelize(device_map)  # Splits the model across several devices
-     model.deparallelize()  # Puts the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
-     ```
- """
-
- # Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
- __HEAD_MASK_WARNING_MSG = """
- The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
- `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
- If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
- num_heads)`.
- """
-
- MT5_START_DOCSTRING = r"""
-
-     The MT5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
-     Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
-     Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a
-     text-to-text denoising generative setting.
-
-     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
-     etc.)
-
-     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
-     and behavior.
-
-     Parameters:
-         config ([`MT5Config`]): Model configuration class with all the parameters of the model.
-             Initializing with a config file does not load the weights associated with the model, only the
-             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
- """
-
- MT5_INPUTS_DOCSTRING = r"""
-     Args:
-         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-             Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
-             should be able to pad the inputs on both the right and the left.
-
-             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-             [`PreTrainedTokenizer.__call__`] for details.
-
-             [What are input IDs?](../glossary#input-ids)
-
-             To know more on how to prepare `input_ids` for pretraining take a look at [MT5 Training](./mt5#training).
-         attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
-             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-             - 1 for tokens that are **not masked**,
-             - 0 for tokens that are **masked**.
-
-             [What are attention masks?](../glossary#attention-mask)
-         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
-             Indices of decoder input sequence tokens in the vocabulary.
-
-             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-             [`PreTrainedTokenizer.__call__`] for details.
-
-             [What are decoder input IDs?](../glossary#decoder-input-ids)
-
-             MT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
-             is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
-
-             To know more on how to prepare `decoder_input_ids` for pretraining take a look at [MT5
-             Training](./mt5#training).
-         decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
-             Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
-             be used by default.
-         head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
-             Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
-             1]`:
-
-             - 1 indicates the head is **not masked**,
-             - 0 indicates the head is **masked**.
-
-         decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
-             Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
-             1]`:
-
-             - 1 indicates the head is **not masked**,
-             - 0 indicates the head is **masked**.
-
-         cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
-             Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
-             `[0, 1]`:
-
-             - 1 indicates the head is **not masked**,
-             - 0 indicates the head is **masked**.
-
-         encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
-             Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
-             `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
-             the output of the last layer of the encoder. Used in the cross-attention of the decoder.
-         past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
-             Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
-
-             If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
-             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
-             `decoder_input_ids` of shape `(batch_size, sequence_length)`.
-         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
-             model's internal embedding lookup matrix.
-         decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
-             Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
-             representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
-             input (see `past_key_values`). This is useful if you want more control over how to convert
-             `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
-
-             If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
-             of `inputs_embeds`.
-
-         use_cache (`bool`, *optional*):
-             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-             `past_key_values`).
-
-         output_attentions (`bool`, *optional*):
-             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-             tensors for more detail.
-         output_hidden_states (`bool`, *optional*):
-             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-             more detail.
-         return_dict (`bool`, *optional*):
-             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- """
-
- MT5_ENCODER_INPUTS_DOCSTRING = r"""
-     Args:
-         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-             Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
-             should be able to pad the inputs on both the right and the left.
-
-             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-             [`PreTrainedTokenizer.__call__`] for details.
-
-             To know more on how to prepare `input_ids` for pretraining take a look at [MT5 Training](./mt5#training).
-         attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
-             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-             - 1 for tokens that are **not masked**,
-             - 0 for tokens that are **masked**.
-
-             [What are attention masks?](../glossary#attention-mask)
-         head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
-             Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
-
-             - 1 indicates the head is **not masked**,
-             - 0 indicates the head is **masked**.
-
-         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
-             model's internal embedding lookup matrix.
-         output_attentions (`bool`, *optional*):
-             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-             tensors for more detail.
-         output_hidden_states (`bool`, *optional*):
-             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-             more detail.
-         return_dict (`bool`, *optional*):
-             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- """
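
`PARALLELIZE_DOCSTRING` above describes `device_map` as a dict mapping a device index to the list of attention blocks placed on it. As a minimal sketch (not part of the commit), assuming two visible GPUs and the 4-block default config from `modeling/config.py`, such a map can be built and validated with the same helpers that `modeling/model.py` already imports:

```python
# Sketch: building the kind of `device_map` described in PARALLELIZE_DOCSTRING.
# Assumes 2 GPUs and the 4 encoder blocks of the default config in this repo.
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map

num_blocks = 4  # config.num_layers in modeling/config.py
device_map = get_device_map(num_blocks, list(range(2)))  # e.g. {0: [0, 1], 1: [2, 3]}
assert_device_map(device_map, num_blocks)  # raises if any block is missing or duplicated
```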
 
modeling/model.py DELETED
@@ -1,612 +0,0 @@
- import copy
- import math
- import warnings
- from typing import List, Optional, Tuple, Union
-
- from transformers import MT5PreTrainedModel
- from transformers.models.mt5 import MT5Stack
- from transformers.modeling_outputs import Seq2SeqModelOutput, Seq2SeqLMOutput, BaseModelOutput
- from transformers.utils import (
-     add_start_docstrings,
-     add_start_docstrings_to_model_forward,
-     logging,
-     replace_return_docstrings,
- )
-
- from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
-
- import torch
- from torch import nn
- from torch.nn import CrossEntropyLoss
-
- from .config import MT5Config
- from .docstrings import (
-     PARALLELIZE_DOCSTRING,
-     DEPARALLELIZE_DOCSTRING,
-     # Imported under an alias: a bare `__HEAD_MASK_WARNING_MSG` reference inside a class body would be
-     # name-mangled (e.g. to `_MT5Model__HEAD_MASK_WARNING_MSG`) and raise a NameError when the warning fires.
-     __HEAD_MASK_WARNING_MSG as HEAD_MASK_WARNING_MSG,
-     MT5_START_DOCSTRING,
-     MT5_INPUTS_DOCSTRING,
- )
-
-
- logger = logging.get_logger(__name__)
-
- _CONFIG_FOR_DOC = "MT5Config"
- _CHECKPOINT_FOR_DOC = "mt5-small"
-
-
- class MT5Model(MT5PreTrainedModel):
-     r"""
-     Examples:
-
-     ```python
-     >>> from transformers import MT5Model, AutoTokenizer
-
-     >>> model = MT5Model.from_pretrained("google/mt5-small")
-     >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
-     >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
-     >>> summary = "Weiter Verhandlung in Syrien."
-     >>> inputs = tokenizer(article, return_tensors="pt")
-     >>> labels = tokenizer(text_target=summary, return_tensors="pt")
-
-     >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
-     >>> hidden_states = outputs.last_hidden_state
-     ```"""
-
-     model_type = "mt5"
-     config_class = MT5Config
-     _keys_to_ignore_on_load_missing = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
-     _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
-     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
-
-     # Copied from transformers.models.t5.modeling_t5.T5Model.__init__ with T5->MT5
-     def __init__(self, config: MT5Config):
-         super().__init__(config)
-         self.encoder_embedding = nn.Embedding(config.encoder_vocab_size, config.d_model)
-         if config.shared_embedding:
-             self.decoder_embedding = self.encoder_embedding
-         else:
-             self.decoder_embedding = nn.Embedding(config.decoder_vocab_size, config.d_model)
-
-         encoder_config = copy.deepcopy(config)
-         encoder_config.is_decoder = False
-         encoder_config.use_cache = False
-         encoder_config.is_encoder_decoder = False
-         self.encoder = MT5Stack(encoder_config, self.encoder_embedding)
-
-         decoder_config = copy.deepcopy(config)
-         decoder_config.is_decoder = True
-         decoder_config.is_encoder_decoder = False
-         decoder_config.num_layers = config.num_decoder_layers
-         self.decoder = MT5Stack(decoder_config, self.decoder_embedding)
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-         # Model parallel
-         self.model_parallel = False
-         self.device_map = None
-
-     @add_start_docstrings(PARALLELIZE_DOCSTRING)
-     # Copied from transformers.models.t5.modeling_t5.T5Model.parallelize
-     def parallelize(self, device_map=None):
-         warnings.warn(
-             "`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
-             " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
-             " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':"
-             " 0, 'encoder.block.1': 1, ...}",
-             FutureWarning,
-         )
-         self.device_map = (
-             get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
-             if device_map is None
-             else device_map
-         )
-         assert_device_map(self.device_map, len(self.encoder.block))
-         self.encoder.parallelize(self.device_map)
-         self.decoder.parallelize(self.device_map)
-         self.model_parallel = True
-
-     @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
-     # Copied from transformers.models.t5.modeling_t5.T5Model.deparallelize
-     def deparallelize(self):
-         warnings.warn(
-             "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
-             FutureWarning,
-         )
-         self.encoder.deparallelize()
-         self.decoder.deparallelize()
-         self.encoder = self.encoder.to("cpu")
-         self.decoder = self.decoder.to("cpu")
-         self.model_parallel = False
-         self.device_map = None
-         torch.cuda.empty_cache()
-
-     # Copied from transformers.models.t5.modeling_t5.T5Model.get_input_embeddings
-     def get_input_embeddings(self):
-         return self.encoder_embedding
-
-     # Copied from transformers.models.t5.modeling_t5.T5Model.set_input_embeddings
-     def set_input_embeddings(self, new_embeddings):
-         self.encoder_embedding = new_embeddings
-         self.encoder.set_input_embeddings(new_embeddings)
-         self.decoder.set_input_embeddings(new_embeddings)
-
-     # Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder
-     def get_encoder(self):
-         return self.encoder
-
-     # Copied from transformers.models.t5.modeling_t5.T5Model.get_decoder
-     def get_decoder(self):
-         return self.decoder
-
-     # Copied from transformers.models.t5.modeling_t5.T5Model._prune_heads
-     def _prune_heads(self, heads_to_prune):
-         """
-         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
-         class PreTrainedModel
-         """
-         for layer, heads in heads_to_prune.items():
-             self.encoder.layer[layer].attention.prune_heads(heads)
-
-     @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
-     # Copied from transformers.models.t5.modeling_t5.T5Model.forward with T5->MT5, t5->mt5
-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         decoder_input_ids: Optional[torch.LongTensor] = None,
-         decoder_attention_mask: Optional[torch.BoolTensor] = None,
-         head_mask: Optional[torch.FloatTensor] = None,
-         decoder_head_mask: Optional[torch.FloatTensor] = None,
-         cross_attn_head_mask: Optional[torch.Tensor] = None,
-         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-         inputs_embeds: Optional[torch.Tensor] = None,
-         decoder_inputs_embeds: Optional[torch.Tensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
-         r"""
-         Returns:
-
-         Example:
-
-         ```python
-         >>> from transformers import AutoTokenizer, MT5Model
-
-         >>> tokenizer = AutoTokenizer.from_pretrained("mt5-small")
-         >>> model = MT5Model.from_pretrained("mt5-small")
-
-         >>> input_ids = tokenizer(
-         ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
-         ... ).input_ids  # Batch size 1
-         >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
-
-         >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for MT5Model.
-         >>> # This is not needed for torch's MT5ForConditionalGeneration as it does this internally using labels arg.
-         >>> decoder_input_ids = model._shift_right(decoder_input_ids)
-
-         >>> # forward pass
-         >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
-         >>> last_hidden_states = outputs.last_hidden_state
-         ```"""
-         use_cache = use_cache if use_cache is not None else self.config.use_cache
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
-         if head_mask is not None and decoder_head_mask is None:
-             if self.config.num_layers == self.config.num_decoder_layers:
-                 warnings.warn(HEAD_MASK_WARNING_MSG, FutureWarning)
-                 decoder_head_mask = head_mask
-
-         # Encode if needed (training, first prediction pass)
-         if encoder_outputs is None:
-             encoder_outputs = self.encoder(
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-                 inputs_embeds=inputs_embeds,
-                 head_mask=head_mask,
-                 output_attentions=output_attentions,
-                 output_hidden_states=output_hidden_states,
-                 return_dict=return_dict,
-             )
-         elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-             encoder_outputs = BaseModelOutput(
-                 last_hidden_state=encoder_outputs[0],
-                 hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                 attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
-             )
-
-         hidden_states = encoder_outputs[0]
-
-         # Set device for model parallelism
-         if self.model_parallel:
-             torch.cuda.set_device(self.decoder.first_device)
-             hidden_states = hidden_states.to(self.decoder.first_device)
-             if decoder_input_ids is not None:
-                 decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
-             if attention_mask is not None:
-                 attention_mask = attention_mask.to(self.decoder.first_device)
-             if decoder_attention_mask is not None:
-                 decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
-
-         # Decode
-         decoder_outputs = self.decoder(
-             input_ids=decoder_input_ids,
-             attention_mask=decoder_attention_mask,
-             inputs_embeds=decoder_inputs_embeds,
-             past_key_values=past_key_values,
-             encoder_hidden_states=hidden_states,
-             encoder_attention_mask=attention_mask,
-             head_mask=decoder_head_mask,
-             cross_attn_head_mask=cross_attn_head_mask,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-
-         if not return_dict:
-             return decoder_outputs + encoder_outputs
-
-         return Seq2SeqModelOutput(
-             last_hidden_state=decoder_outputs.last_hidden_state,
-             past_key_values=decoder_outputs.past_key_values,
-             decoder_hidden_states=decoder_outputs.hidden_states,
-             decoder_attentions=decoder_outputs.attentions,
-             cross_attentions=decoder_outputs.cross_attentions,
-             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
-             encoder_hidden_states=encoder_outputs.hidden_states,
-             encoder_attentions=encoder_outputs.attentions,
-         )
-
-
- @add_start_docstrings("""MT5 Model with a `language modeling` head on top.""", MT5_START_DOCSTRING)
- class MT5ForConditionalGeneration(MT5PreTrainedModel):
-     r"""
-     Examples:
-
-     ```python
-     >>> from transformers import MT5ForConditionalGeneration, AutoTokenizer
-
-     >>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
-     >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
-     >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
-     >>> summary = "Weiter Verhandlung in Syrien."
-     >>> inputs = tokenizer(article, text_target=summary, return_tensors="pt")
-
-     >>> outputs = model(**inputs)
-     >>> loss = outputs.loss
-     ```"""
-
-     model_type = "mt5"
-     config_class = MT5Config
-     _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
-     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
-
-     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.__init__ with T5->MT5
-     def __init__(self, config: MT5Config):
-         super().__init__(config)
-         self.model_dim = config.d_model
-
-         self.encoder_embedding = nn.Embedding(config.encoder_vocab_size, config.d_model)
-         if config.shared_embedding:
-             self.decoder_embedding = self.encoder_embedding
-         else:
-             self.decoder_embedding = nn.Embedding(config.decoder_vocab_size, config.d_model)
-
-         encoder_config = copy.deepcopy(config)
-         encoder_config.is_decoder = False
-         encoder_config.use_cache = False
-         encoder_config.is_encoder_decoder = False
-         self.encoder = MT5Stack(encoder_config, self.encoder_embedding)
-
-         decoder_config = copy.deepcopy(config)
-         decoder_config.is_decoder = True
-         decoder_config.is_encoder_decoder = False
-         decoder_config.num_layers = config.num_decoder_layers
-         self.decoder = MT5Stack(decoder_config, self.decoder_embedding)
-
-         self.lm_head = nn.Linear(config.d_model, config.decoder_vocab_size, bias=False)
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-         # Model parallel
-         self.model_parallel = False
-         self.device_map = None
-
-     @add_start_docstrings(PARALLELIZE_DOCSTRING)
-     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.parallelize
-     def parallelize(self, device_map=None):
-         warnings.warn(
-             "`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
-             " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
-             " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
-             " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
-             FutureWarning,
-         )
-         self.device_map = (
-             get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
-             if device_map is None
-             else device_map
-         )
-         assert_device_map(self.device_map, len(self.encoder.block))
-         self.encoder.parallelize(self.device_map)
-         self.decoder.parallelize(self.device_map)
-         self.lm_head = self.lm_head.to(self.decoder.first_device)
-         self.model_parallel = True
-
-     @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
-     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.deparallelize
-     def deparallelize(self):
-         warnings.warn(
-             "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
-             FutureWarning,
-         )
-         self.encoder.deparallelize()
-         self.decoder.deparallelize()
-         self.encoder = self.encoder.to("cpu")
-         self.decoder = self.decoder.to("cpu")
-         self.lm_head = self.lm_head.to("cpu")
-         self.model_parallel = False
-         self.device_map = None
-         torch.cuda.empty_cache()
-
-     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_input_embeddings
-     def get_input_embeddings(self):
-         return self.encoder_embedding
-
-     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.set_input_embeddings
-     def set_input_embeddings(self, new_embeddings):
-         self.encoder_embedding = new_embeddings
-         self.encoder.set_input_embeddings(new_embeddings)
-         self.decoder.set_input_embeddings(new_embeddings)
-
-     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.set_output_embeddings
-     def set_output_embeddings(self, new_embeddings):
-         self.lm_head = new_embeddings
-
-     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_output_embeddings
-     def get_output_embeddings(self):
-         return self.lm_head
-
-     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_encoder
-     def get_encoder(self):
-         return self.encoder
-
-     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_decoder
-     def get_decoder(self):
-         return self.decoder
-
-     @add_start_docstrings_to_model_forward(MT5_INPUTS_DOCSTRING)
-     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
-     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward with T5->MT5, t5->mt5
-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         decoder_input_ids: Optional[torch.LongTensor] = None,
-         decoder_attention_mask: Optional[torch.BoolTensor] = None,
-         head_mask: Optional[torch.FloatTensor] = None,
-         decoder_head_mask: Optional[torch.FloatTensor] = None,
-         cross_attn_head_mask: Optional[torch.Tensor] = None,
-         encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
-         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
-         r"""
-         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-             Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ...,
-             config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
-             labels in `[0, ..., config.vocab_size]`
-
-         Returns:
-
-         Examples:
-
-         ```python
-         >>> from transformers import AutoTokenizer, MT5ForConditionalGeneration
-
-         >>> tokenizer = AutoTokenizer.from_pretrained("mt5-small")
-         >>> model = MT5ForConditionalGeneration.from_pretrained("mt5-small")
-
-         >>> # training
-         >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
-         >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
-         >>> outputs = model(input_ids=input_ids, labels=labels)
-         >>> loss = outputs.loss
-         >>> logits = outputs.logits
-
-         >>> # inference
-         >>> input_ids = tokenizer(
-         ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
-         ... ).input_ids  # Batch size 1
-         >>> outputs = model.generate(input_ids)
-         >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
-         >>> # studies have shown that owning a dog is good for you.
-         ```"""
-         use_cache = use_cache if use_cache is not None else self.config.use_cache
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
-         if head_mask is not None and decoder_head_mask is None:
-             if self.config.num_layers == self.config.num_decoder_layers:
-                 warnings.warn(HEAD_MASK_WARNING_MSG, FutureWarning)
-                 decoder_head_mask = head_mask
-
-         # Encode if needed (training, first prediction pass)
-         if encoder_outputs is None:
-             # Convert encoder inputs in embeddings if needed
-             encoder_outputs = self.encoder(
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-                 inputs_embeds=inputs_embeds,
-                 head_mask=head_mask,
-                 output_attentions=output_attentions,
-                 output_hidden_states=output_hidden_states,
-                 return_dict=return_dict,
-             )
-         elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-             encoder_outputs = BaseModelOutput(
-                 last_hidden_state=encoder_outputs[0],
-                 hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                 attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
-             )
-
-         hidden_states = encoder_outputs[0]
-
-         if self.model_parallel:
-             torch.cuda.set_device(self.decoder.first_device)
-
-         if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
-             # get decoder inputs from shifting lm labels to the right
-             decoder_input_ids = self._shift_right(labels)
-
-         # Set device for model parallelism
-         if self.model_parallel:
-             torch.cuda.set_device(self.decoder.first_device)
-             hidden_states = hidden_states.to(self.decoder.first_device)
-             if decoder_input_ids is not None:
-                 decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
-             if attention_mask is not None:
-                 attention_mask = attention_mask.to(self.decoder.first_device)
-             if decoder_attention_mask is not None:
-                 decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
-
-         # Decode
-         decoder_outputs = self.decoder(
-             input_ids=decoder_input_ids,
-             attention_mask=decoder_attention_mask,
-             inputs_embeds=decoder_inputs_embeds,
-             past_key_values=past_key_values,
-             encoder_hidden_states=hidden_states,
-             encoder_attention_mask=attention_mask,
-             head_mask=decoder_head_mask,
-             cross_attn_head_mask=cross_attn_head_mask,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-
-         sequence_output = decoder_outputs[0]
-
-         # Set device for model parallelism
-         if self.model_parallel:
-             torch.cuda.set_device(self.encoder.first_device)
-             self.lm_head = self.lm_head.to(self.encoder.first_device)
-             sequence_output = sequence_output.to(self.lm_head.weight.device)
-
-         if self.config.tie_word_embeddings:
-             # Rescale output before projecting on vocab
-             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
-             sequence_output = sequence_output * (self.model_dim**-0.5)
-
-         lm_logits = self.lm_head(sequence_output)
-
-         loss = None
-         if labels is not None:
-             loss_fct = CrossEntropyLoss(ignore_index=-100)
-             # move labels to correct device to enable PP
-             labels = labels.to(lm_logits.device)
-             loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
-             # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
-
-         if not return_dict:
-             output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
-             return ((loss,) + output) if loss is not None else output
-
-         return Seq2SeqLMOutput(
-             loss=loss,
-             logits=lm_logits,
-             past_key_values=decoder_outputs.past_key_values,
-             decoder_hidden_states=decoder_outputs.hidden_states,
-             decoder_attentions=decoder_outputs.attentions,
-             cross_attentions=decoder_outputs.cross_attentions,
-             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
-             encoder_hidden_states=encoder_outputs.hidden_states,
-             encoder_attentions=encoder_outputs.attentions,
-         )
-
-     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_inputs_for_generation
-     def prepare_inputs_for_generation(
-         self,
-         input_ids,
-         past_key_values=None,
-         attention_mask=None,
-         head_mask=None,
-         decoder_head_mask=None,
-         decoder_attention_mask=None,
-         cross_attn_head_mask=None,
-         use_cache=None,
-         encoder_outputs=None,
-         **kwargs,
-     ):
-         # cut decoder_input_ids if past_key_values is used
-         if past_key_values is not None:
-             past_length = past_key_values[0][0].shape[2]
-
-             # Some generation methods already pass only the last input ID
-             if input_ids.shape[1] > past_length:
-                 remove_prefix_length = past_length
-             else:
-                 # Default to old behavior: keep only final ID
-                 remove_prefix_length = input_ids.shape[1] - 1
-
-             input_ids = input_ids[:, remove_prefix_length:]
-
-         return {
-             "decoder_input_ids": input_ids,
-             "past_key_values": past_key_values,
-             "encoder_outputs": encoder_outputs,
-             "attention_mask": attention_mask,
-             "head_mask": head_mask,
-             "decoder_head_mask": decoder_head_mask,
-             "decoder_attention_mask": decoder_attention_mask,
-             "cross_attn_head_mask": cross_attn_head_mask,
-             "use_cache": use_cache,
-         }
-
-     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels
-     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
-         return self._shift_right(labels)
-
-     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._reorder_cache
-     def _reorder_cache(self, past_key_values, beam_idx):
-         # if decoder past is not included in output
-         # speedy decoding is disabled and no need to reorder
-         if past_key_values is None:
-             logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
-             return past_key_values
-
-         reordered_decoder_past = ()
-         for layer_past_states in past_key_values:
-             # get the correct batch idx from layer past batch dim
-             # batch dim of `past` is at 2nd position
-             reordered_layer_past_states = ()
-             for layer_past_state in layer_past_states:
-                 # need to set correct `past` for each of the four key / value states
-                 reordered_layer_past_states = reordered_layer_past_states + (
-                     layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
-                 )
-
-             if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
-                 raise ValueError(
-                     f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
-                 )
-             if len(reordered_layer_past_states) != len(layer_past_states):
-                 raise ValueError(
-                     f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
-                 )
-
-             reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
-         return reordered_decoder_past
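
For context, a hypothetical smoke test (not part of the commit) showing how the deleted pieces fit together before this commit: the config drives two separate embedding matrices, and `lm_head` projects onto the decoder vocabulary. The tiny sizes and random inputs below are illustrative assumptions.

```python
# Hypothetical smoke test against the files above (runnable only before this deletion).
import torch
from modeling import MT5ForConditionalGeneration
from modeling.config import MT5Config

config = MT5Config(encoder_vocab_size=100, decoder_vocab_size=80, d_model=32, d_kv=8,
                   d_ff=64, num_layers=2, num_heads=2)
model = MT5ForConditionalGeneration(config)  # randomly initialized, no checkpoint

input_ids = torch.randint(3, 100, (2, 7))    # batch of 2, source length 7, encoder vocabulary
labels = torch.randint(3, 80, (2, 5))        # target length 5, decoder vocabulary
outputs = model(input_ids=input_ids, labels=labels)  # labels are shifted right internally

# The LM head projects to the *decoder* vocabulary, not the encoder one.
assert outputs.logits.shape == (2, 5, config.decoder_vocab_size)
print(float(outputs.loss))
```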