Automatic Speech Recognition
Transformers
PyTorch
English
joint_aed_ctc_speech-encoder-decoder
custom_code
Eval Results
Lakoc committed on
Commit e9acf97
1 Parent(s): 2b3d2be

Upload JointCTCAttentionEncoderDecoder

auto_wrappers.py ADDED
@@ -0,0 +1,132 @@
import copy
import os

from transformers import AutoConfig, AutoModelForCTC, PretrainedConfig
from transformers.dynamic_module_utils import (
    get_class_from_dynamic_module,
    resolve_trust_remote_code,
)
from transformers.models.auto.auto_factory import _get_model_class

from .extractors import Conv2dFeatureExtractor


class FeatureExtractionInitModifier(type):
    def __new__(cls, name, bases, dct):
        # Create the class using the original definition
        new_cls = super().__new__(cls, name, bases, dct)

        # Save the original __init__ method
        original_init = new_cls.__init__

        # Modify the __init__ method dynamically
        def new_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            if self.config.expect_2d_input:
                getattr(self, self.base_model_prefix).feature_extractor = Conv2dFeatureExtractor(self.config)

        # Replace the __init__ method with the modified version
        new_cls.__init__ = new_init

        return new_cls


class CustomAutoModelForCTC(AutoModelForCTC):
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        config = kwargs.pop("config", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs["_from_auto"] = True
        hub_kwargs_names = [
            "cache_dir",
            "code_revision",
            "force_download",
            "local_files_only",
            "proxies",
            "resume_download",
            "revision",
            "subfolder",
            "use_auth_token",
        ]
        hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
        if not isinstance(config, PretrainedConfig):
            kwargs_orig = copy.deepcopy(kwargs)
            # ensure not to pollute the config object with torch_dtype="auto" - since it's
            # meaningless in the context of the config object - torch.dtype values are acceptable
            if kwargs.get("torch_dtype", None) == "auto":
                _ = kwargs.pop("torch_dtype")

            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                return_unused_kwargs=True,
                trust_remote_code=trust_remote_code,
                **hub_kwargs,
                **kwargs,
            )

            # if torch_dtype=auto was passed here, ensure to pass it on
            if kwargs_orig.get("torch_dtype", None) == "auto":
                kwargs["torch_dtype"] = "auto"

        has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
        has_local_code = type(config) in cls._model_mapping.keys()
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
        )
        if has_remote_code and trust_remote_code:
            class_ref = config.auto_map[cls.__name__]
            model_class = get_class_from_dynamic_module(
                class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs
            )
            model_class = FeatureExtractionInitModifier(model_class.__name__, (model_class,), {})
            _ = hub_kwargs.pop("code_revision", None)
            if os.path.isdir(pretrained_model_name_or_path):
                model_class.register_for_auto_class(cls.__name__)
            else:
                cls.register(config.__class__, model_class, exist_ok=True)
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )
        elif type(config) in cls._model_mapping.keys():
            model_class = _get_model_class(config, cls._model_mapping)
            model_class = FeatureExtractionInitModifier(model_class.__name__, (model_class,), {})
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )
        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
        )

    @classmethod
    def from_config(cls, config, **kwargs):
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
        has_local_code = type(config) in cls._model_mapping.keys()
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code, config._name_or_path, has_local_code, has_remote_code
        )

        if has_remote_code and trust_remote_code:
            class_ref = config.auto_map[cls.__name__]
            if "--" in class_ref:
                repo_id, class_ref = class_ref.split("--")
            else:
                repo_id = config.name_or_path
            model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
            if os.path.isdir(config._name_or_path):
                model_class.register_for_auto_class(cls.__name__)
            else:
                cls.register(config.__class__, model_class, exist_ok=True)
            _ = kwargs.pop("code_revision", None)
            model_class = FeatureExtractionInitModifier(model_class.__name__, (model_class,), {})
            return model_class._from_config(config, **kwargs)
        elif type(config) in cls._model_mapping.keys():
            model_class = _get_model_class(config, cls._model_mapping)
            model_class = FeatureExtractionInitModifier(model_class.__name__, (model_class,), {})
            return model_class._from_config(config, **kwargs)

        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
        )
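A brief note on what the wrapper above buys you: CustomAutoModelForCTC resolves the concrete model class exactly like AutoModelForCTC, but wraps it in the FeatureExtractionInitModifier metaclass so that, whenever config.expect_2d_input is set, the default 1-D convolutional front-end is swapped for Conv2dFeatureExtractor right after __init__ runs. The sketch below reproduces only the metaclass mechanism in isolation; DummyConfig and DummyModel are illustrative stand-ins, not classes from this repository.

# Minimal, self-contained sketch of the FeatureExtractionInitModifier idea.
class InitModifier(type):
    def __new__(cls, name, bases, dct):
        new_cls = super().__new__(cls, name, bases, dct)
        original_init = new_cls.__init__

        def new_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            if getattr(self.config, "expect_2d_input", False):
                # in the real wrapper this is Conv2dFeatureExtractor(self.config)
                self.feature_extractor = "2-D conv extractor (placeholder)"

        new_cls.__init__ = new_init
        return new_cls


class DummyConfig:
    expect_2d_input = True


class DummyModel:
    def __init__(self, config):
        self.config = config
        self.feature_extractor = "default 1-D extractor"


PatchedModel = InitModifier(DummyModel.__name__, (DummyModel,), {})
print(PatchedModel(DummyConfig()).feature_extractor)  # -> the 2-D placeholder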
config.json ADDED
@@ -0,0 +1,288 @@
{
  "_commit_hash": null,
  "_name_or_path": "/Users/alexanderpolok/PycharmProjects/huggingface_asr/checkpoint-378950",
  "architectures": [
    "JointCTCAttentionEncoderDecoder"
  ],
  "auto_map": {
    "AutoConfig": "configuration_reguler.JointCTCAttentionEncoderDecoderConfig",
    "AutoModelForSpeechSeq2Seq": "modeling_reguler.JointCTCAttentionEncoderDecoder"
  },
  "ctc_weight": 0.3,
  "decoder": {
    "_name_or_path": "Lakoc/gpt2_512h_8l_add_head6_04",
    "activation_function": "gelu_new",
    "add_cross_attention": true,
    "architectures": null,
    "attn_pdrop": 0.1,
    "average_logits": false,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "embd_pdrop": 0.1,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 1,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "head_locations": [
      5
    ],
    "head_weights": [
      0.6,
      0.4
    ],
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "is_decoder": true,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_epsilon": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "gpt2-multi-head",
    "n_embd": 512,
    "n_head": 8,
    "n_inner": 2048,
    "n_layer": 8,
    "n_positions": 1024,
    "no_repeat_ngram_size": 0,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "pos_emb_fixed": true,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "reorder_and_upcast_attn": false,
    "repetition_penalty": 1.0,
    "resid_pdrop": 0.1,
    "return_dict": true,
    "return_dict_in_generate": false,
    "scale_attn_by_inverse_layer_idx": false,
    "scale_attn_weights": true,
    "sep_token_id": null,
    "summary_activation": null,
    "summary_first_dropout": 0.1,
    "summary_proj_to_labels": true,
    "summary_type": "cls_index",
    "summary_use_proj": true,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_additional_weights": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": false,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.31.0",
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_cache": true,
    "vocab_size": 5000
  },
  "decoder_pos_emb_fixed": true,
  "decoder_start_token_id": 0,
  "decoder_vocab_size": 5000,
  "encoder": {
    "_name_or_path": "Lakoc/ebranchformer_16l_512h",
    "activation_dropout": 0.1,
    "adapter_attn_dim": null,
    "adapter_kernel_size": 3,
    "adapter_stride": 2,
    "add_adapter": false,
    "add_cross_attention": false,
    "apply_spec_augment": false,
    "apply_time_warp": false,
    "architectures": null,
    "attention_dropout": 0.1,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 1,
    "chunk_size_feed_forward": 0,
    "classifier_proj_size": 256,
    "codevector_dim": 256,
    "conformer_conv_dropout": 0.1,
    "contrastive_logits_temperature": 0.1,
    "conv_bias": false,
    "conv_depthwise_kernel_size": 31,
    "conv_dim": [
      512,
      512
    ],
    "conv_kernel": [
      3,
      3
    ],
    "conv_stride": [
      2,
      2
    ],
    "cross_attention_hidden_size": null,
    "csgu_activation": "identity",
    "csgu_conv_dropout": 0.1,
    "csgu_kernel_size": 31,
    "csgu_use_linear_after_conv": false,
    "ctc_loss_reduction": "mean",
    "ctc_zero_infinity": true,
    "decoder_start_token_id": null,
    "diversity_loss_weight": 0.1,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "do_stable_layer_norm": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 2,
    "expect_2d_input": true,
    "exponential_decay_length_penalty": null,
    "fe_position_embeddings": true,
    "feat_extract_activation": "gelu",
    "feat_extract_norm": "group",
    "feat_proj_dropout": 0.0,
    "feat_quantizer_dropout": 0.0,
    "final_dropout": 0.1,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "gelu",
    "hidden_dropout": 0.1,
    "hidden_size": 512,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "intermediate_size": 2048,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "layerdrop": 0.0,
    "length_penalty": 1.0,
    "mask_feature_length": 10,
    "mask_feature_min_masks": 0,
    "mask_feature_prob": 0.0,
    "mask_time_length": 10,
    "mask_time_min_masks": 2,
    "mask_time_prob": 0.05,
    "max_length": 20,
    "max_source_positions": 1024,
    "merge_conv_kernel": 31,
    "min_length": 0,
    "model_type": "wav2vec2-ebranchformer",
    "no_repeat_ngram_size": 0,
    "num_adapter_layers": 3,
    "num_attention_heads": 4,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_codevector_groups": 2,
    "num_codevectors_per_group": 320,
    "num_conv_pos_embedding_groups": 16,
    "num_conv_pos_embeddings": 128,
    "num_feat_extract_layers": 2,
    "num_hidden_layers": 16,
    "num_mel_bins": 80,
    "num_negatives": 100,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_size": 512,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": 3,
    "position_embeddings_type": "relative",
    "prefix": null,
    "problem_type": null,
    "proj_codevector_dim": 256,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "rotary_embedding_base": 10000,
    "second_dim_input_size": 80,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "tdnn_dilation": [
      1,
      2,
      3,
      1,
      1
    ],
    "tdnn_dim": [
      512,
      512,
      512,
      512,
      1500
    ],
    "tdnn_kernel": [
      5,
      3,
      3,
      1,
      1
    ],
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "time_warp_mode": "bicubic",
    "time_warp_window": 5,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.31.0",
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_fbanks": true,
    "use_macaron_ff": true,
    "use_weighted_layer_sum": false,
    "vocab_size": 5000,
    "xvector_output_dim": 512
  },
  "encoder_ctc_loss_reduction": "mean",
  "encoder_expect_2d_input": true,
  "encoder_layerdrop": 0.0,
  "encoder_pad_token_id": 3,
  "encoder_second_dim_input_size": 80,
  "encoder_vocab_size": 5000,
  "is_encoder_decoder": true,
  "lsm_factor": 0.1,
  "model_type": "joint_aed_ctc_speech-encoder-decoder",
  "pad_token_id": 3,
  "shared_lm_head": false,
  "tie_word_embeddings": false,
  "tokenizer_class": "PreTrainedTokenizerFast",
  "torch_dtype": "float32",
  "transformers_version": null
}
configuration_reguler.py ADDED
@@ -0,0 +1,23 @@
from transformers import AutoConfig, AutoModelForCausalLM, SpeechEncoderDecoderConfig

from .auto_wrappers import CustomAutoModelForCTC
from .e_branchformer import Wav2Vec2EBranchformerConfig, Wav2Vec2EBranchformerForCTC
from .multi_head_gpt2 import GPT2LMMultiHeadModel, GPT2MultiHeadConfig
from .residual_clasiffier_gpt2 import (
    GPT2ResidualsLMHeadConfig,
    GPT2ResidualsLMHeadModel,
)

AutoConfig.register("gpt2-multi-head", GPT2MultiHeadConfig)
AutoModelForCausalLM.register(GPT2MultiHeadConfig, GPT2LMMultiHeadModel)

AutoConfig.register("gpt2-residuals-head", GPT2ResidualsLMHeadConfig)
AutoModelForCausalLM.register(GPT2ResidualsLMHeadConfig, GPT2ResidualsLMHeadModel)

AutoConfig.register("wav2vec2-ebranchformer", Wav2Vec2EBranchformerConfig)
CustomAutoModelForCTC.register(Wav2Vec2EBranchformerConfig, Wav2Vec2EBranchformerForCTC)


class JointCTCAttentionEncoderDecoderConfig(SpeechEncoderDecoderConfig):
    model_type = "joint_aed_ctc_speech-encoder-decoder"
    is_composition = True
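Because config.json maps AutoConfig and AutoModelForSpeechSeq2Seq to these remote-code classes, loading the checkpoint only requires trust_remote_code=True. A hedged usage sketch follows; "Lakoc/<this-repo>" is a placeholder, since the repository id is not stated on this page.

from transformers import AutoModelForSpeechSeq2Seq

# Pulls configuration_reguler.py / modeling_reguler.py from this commit via auto_map.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "Lakoc/<this-repo>",    # hypothetical repo id, substitute the real one
    trust_remote_code=True,  # required: the model_type is not part of core Transformers
)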
ctc_scorer.py ADDED
@@ -0,0 +1,311 @@
# pylint: skip-file
# Copied from: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py
import torch
from transformers import GenerationConfig, LogitsProcessor


class GenerationConfigWithCTC(GenerationConfig):
    def __init__(self, ctc_weight=0.0, ctc_margin=0, **kwargs):
        super().__init__(**kwargs)
        self.ctc_weight = ctc_weight
        self.ctc_margin = ctc_margin


class CTCPrefixScoreTH(object):
    """Batch processing of CTCPrefixScore

    which is based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the label probabilities for multiple
    hypotheses simultaneously
    See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
    Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
    """

    def __init__(self, x, xlens, blank, eos, margin=0):
        """Construct CTC prefix scorer

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        :param torch.Tensor xlens: input lengths (B,)
        :param int blank: blank label id
        :param int eos: end-of-sequence id
        :param int margin: margin parameter for windowing (0 means no windowing)
        """
        # In the comment lines,
        # we assume T: input_length, B: batch size, W: beam width, O: output dim.
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.batch = x.size(0)
        self.input_length = x.size(1)
        self.odim = x.size(2)
        self.dtype = x.dtype
        self.device = torch.device("cuda:%d" % x.get_device()) if x.is_cuda else torch.device("cpu")
        # Pad the rest of posteriors in the batch
        # TODO(takaaki-hori): need a better way without for-loops
        for i, l in enumerate(xlens):
            if l < self.input_length:
                x[i, l:, :] = self.logzero
                x[i, l:, blank] = 0
        # Reshape input x
        xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
        self.x = torch.stack([xn, xb])  # (2, T, B, O)
        self.end_frames = torch.as_tensor(xlens) - 1

        # Setup CTC windowing
        self.margin = margin
        if margin > 0:
            self.frame_ids = torch.arange(self.input_length, dtype=self.dtype, device=self.device)
        # Base indices for index conversion
        self.idx_bh = None
        self.idx_b = torch.arange(self.batch, device=self.device)
        self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)

    def __call__(self, y, state, scoring_ids=None, att_w=None):
        """Compute CTC prefix scores for next labels

        :param list y: prefix label sequences
        :param tuple state: previous CTC state
        :param torch.Tensor att_w: attention weights to decide CTC window
        :return new_state, ctc_local_scores (BW, O)
        """

        # print(self.tokenizer.batch_decode(y))
        output_length = len(y[0]) - 1  # ignore sos
        last_ids = [yi[-1] for yi in y]  # last output label ids
        n_bh = len(last_ids)  # batch * hyps
        n_hyps = n_bh // self.batch  # assuming each utterance has the same # of hyps
        self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
        # prepare state info
        if state is None:
            r_prev = torch.full(
                (self.input_length, 2, self.batch, n_hyps),
                self.logzero,
                dtype=self.dtype,
                device=self.device,
            )
            r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
            r_prev = r_prev.view(-1, 2, n_bh)
            s_prev = 0.0
            f_min_prev = 0
            f_max_prev = 1
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

        # select input dimensions for scoring
        if self.scoring_num > 0:
            scoring_idmap = torch.full((n_bh, self.odim), -1, dtype=torch.long, device=self.device)
            snum = self.scoring_num
            if self.idx_bh is None or n_bh > len(self.idx_bh):
                self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(snum, device=self.device)
            scoring_idx = (scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)).view(-1)
            x_ = torch.index_select(self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx).view(2, -1, n_bh, snum)
        else:
            scoring_ids = None
            scoring_idmap = None
            snum = self.odim
            x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)

        # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
        # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
        r = torch.full(
            (self.input_length, 2, n_bh, snum),
            self.logzero,
            dtype=self.dtype,
            device=self.device,
        )
        if output_length == 0:
            r[0, 0] = x_[0, 0]

        r_sum = torch.logsumexp(r_prev, 1)
        log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
        if scoring_ids is not None:
            for idx in range(n_bh):
                pos = scoring_idmap[idx, last_ids[idx]]
                if pos >= 0:
                    log_phi[:, idx, pos] = r_prev[:, 1, idx]
        else:
            for idx in range(n_bh):
                log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]

        # decide start and end frames based on attention weights
        if att_w is not None and self.margin > 0:
            f_arg = torch.matmul(att_w, self.frame_ids)
            f_min = max(int(f_arg.min().cpu()), f_min_prev)
            f_max = max(int(f_arg.max().cpu()), f_max_prev)
            start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
            end = min(f_max + self.margin, self.input_length)
        else:
            f_min = f_max = 0
            start = max(output_length, 1)
            end = self.input_length

        if start > end:
            return torch.full_like(s_prev, self.logzero), (
                r,
                torch.full_like(s_prev, self.logzero),
                f_min,
                f_max,
                scoring_idmap,
            )

        # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
        for t in range(start, end):
            rp = r[t - 1]
            rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(2, 2, n_bh, snum)
            r[t] = torch.logsumexp(rr, 1) + x_[:, t]

        # compute log prefix probabilities log(psi)
        log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
        if scoring_ids is not None:
            log_psi = torch.full((n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device)
            log_psi_ = torch.logsumexp(
                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
                dim=0,
            )
            for si in range(n_bh):
                log_psi[si, scoring_ids[si]] = log_psi_[si]
        else:
            log_psi = torch.logsumexp(
                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
                dim=0,
            )

        for si in range(n_bh):
            log_psi[si, self.eos] = max(log_psi[si, self.eos], r_sum[self.end_frames[si // n_hyps], si])

        # exclude blank probs
        log_psi[:, self.blank] = self.logzero

        token_scores = log_psi - s_prev
        token_scores[token_scores == 0] = self.logzero

        return token_scores, (r, log_psi, f_min, f_max, scoring_idmap)

    def index_select_state(self, state, best_ids):
        """Select CTC states according to best ids

        :param state    : CTC state
        :param best_ids : index numbers selected by beam pruning (B, W)
        :return selected_state
        """
        r, s, f_min, f_max, scoring_idmap = state
        # convert ids to BHO space
        n_bh = len(s)
        n_hyps = n_bh // self.batch
        vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
        # select hypothesis scores
        s_new = torch.index_select(s.view(-1), 0, vidx)
        s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
        # convert ids to BHS space (S: scoring_num)
        if scoring_idmap is not None:
            snum = self.scoring_num
            hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(-1)
            label_ids = torch.fmod(best_ids, self.odim).view(-1)
            score_idx = scoring_idmap[hyp_idx, label_ids]
            score_idx[score_idx == -1] = 0
            vidx = score_idx + hyp_idx * snum
        else:
            snum = self.odim
        # select forward probabilities
        r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(-1, 2, n_bh)
        return r_new, s_new, f_min, f_max

    def extend_prob(self, x):
        """Extend CTC prob.

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        """

        if self.x.shape[1] < x.shape[1]:  # self.x (2,T,B,O); x (B,T,O)
            # Pad the rest of posteriors in the batch
            # TODO(takaaki-hori): need a better way without for-loops
            xlens = [x.size(1)]
            for i, l in enumerate(xlens):
                if l < self.input_length:
                    x[i, l:, :] = self.logzero
                    x[i, l:, self.blank] = 0
            tmp_x = self.x
            xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
            xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
            self.x = torch.stack([xn, xb])  # (2, T, B, O)
            self.x[:, : tmp_x.shape[1], :, :] = tmp_x
            self.input_length = x.size(1)
            self.end_frames = torch.as_tensor(xlens) - 1

    def extend_state(self, state):
        """Compute CTC prefix state.

        :param state    : CTC state
        :return ctc_state
        """

        if state is None:
            # nothing to do
            return state
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

            r_prev_new = torch.full(
                (self.input_length, 2),
                self.logzero,
                dtype=self.dtype,
                device=self.device,
            )
            start = max(r_prev.shape[0], 1)
            r_prev_new[0:start] = r_prev
            for t in range(start, self.input_length):
                r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]

            return (r_prev_new, s_prev, f_min_prev, f_max_prev)


class CTCRescorerLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        encoder_logits: torch.FloatTensor,
        encoder_output_lens: torch.LongTensor,
        pad_token_id: int,
        eos_token_id: int,
        ctc_margin: int,
        ctc_weight: float,
        num_beams: int,
    ):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.ctc_prefix_scorer = CTCPrefixScoreTH(
            torch.nn.functional.log_softmax(encoder_logits, dim=-1),
            encoder_output_lens,
            pad_token_id,
            eos_token_id,
            ctc_margin,
        )
        self.ctc_weight = ctc_weight
        self.ctc_states = None
        self.num_beams = num_beams

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores[:, self.pad_token_id] = self.ctc_prefix_scorer.logzero
        if self.ctc_states is not None:
            self.ctc_states = self.ctc_prefix_scorer.index_select_state(
                self.ctc_states, input_ids[:, -1].reshape(-1, self.num_beams)
            )
        ctc_scores, ctc_states = self.ctc_prefix_scorer(input_ids, self.ctc_states)
        self.ctc_states = ctc_states
        next_token_scores = (1 - self.ctc_weight) * scores + self.ctc_weight * ctc_scores
        # return scores
        return next_token_scores


class LogSoftmaxProcessor(LogitsProcessor):
    def __init__(
        self,
    ):
        super().__init__()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores = torch.nn.functional.log_softmax(scores, dim=-1)
        return scores
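The rescorer above interpolates the attention decoder's log-probabilities with the CTC prefix scores: next_token_scores = (1 - w) * attention + w * ctc, where w is ctc_weight (0.3 in this checkpoint's config.json). A minimal sketch of just that interpolation on dummy scores, without instantiating the scorer class:

import torch

ctc_weight = 0.3  # matches "ctc_weight" in config.json
att_scores = torch.log_softmax(torch.randn(2, 5000), dim=-1)  # decoder scores, (beams, vocab)
ctc_scores = torch.log_softmax(torch.randn(2, 5000), dim=-1)  # CTC prefix scores, (beams, vocab)
next_token_scores = (1 - ctc_weight) * att_scores + ctc_weight * ctc_scores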
e_branchformer.py ADDED
@@ -0,0 +1,252 @@
""" PyTorch Wav2Vec2-Ebranchformer model."""

from typing import Optional

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Config,
    Wav2Vec2ForCTC,
    Wav2Vec2ForPreTraining,
)
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
    Wav2Vec2ConformerConfig,
    Wav2Vec2ConformerEncoder,
)
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
    Wav2Vec2ConformerFeedForward as Wav2Vec2EBranchformerFeedForward,
)
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
    Wav2Vec2ConformerModel,
)
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
    Wav2Vec2ConformerSelfAttention as Wav2Vec2EBranchformerSelfAttention,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)


class Wav2Vec2EBranchformerConfig(Wav2Vec2ConformerConfig, Wav2Vec2Config):
    """Config for EBranchformer model extending conformer."""

    model_type = "wav2vec2-ebranchformer"

    def __init__(
        self,
        ebranchformer_conv_dropout=0.1,
        csgu_activation="identity",
        csgu_kernel_size=31,
        csgu_use_linear_after_conv=False,
        merge_conv_kernel=31,
        use_macaron_ff=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # EBranchformer related params
        self.csgu_kernel_size = csgu_kernel_size
        self.csgu_activation = csgu_activation
        self.csgu_conv_dropout = ebranchformer_conv_dropout
        self.csgu_use_linear_after_conv = csgu_use_linear_after_conv
        self.merge_conv_kernel = merge_conv_kernel
        self.use_macaron_ff = use_macaron_ff


class ConvolutionalSpatialGatingUnit(torch.nn.Module):
    """Convolutional Spatial Gating Unit (CSGU)."""

    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__()

        n_channels = config.intermediate_size // 2  # split input channels
        self.norm = torch.nn.LayerNorm(n_channels)
        self.conv = torch.nn.Conv1d(
            n_channels,
            n_channels,
            config.csgu_kernel_size,
            1,
            (config.csgu_kernel_size - 1) // 2,
            groups=n_channels,
        )
        if config.csgu_use_linear_after_conv:
            self.linear = torch.nn.Linear(n_channels, n_channels)
        else:
            self.linear = None

        if config.csgu_activation == "identity":
            self.act = torch.nn.Identity()
        else:
            self.act = ACT2FN[config.csgu_activation]

        self.dropout = torch.nn.Dropout(config.csgu_conv_dropout)

    def forward(self, hidden_states: torch.FloatTensor):
        """Forward method

        Args:
            hidden_states (torch.Tensor): (N, T, D)

        Returns:
            out (torch.Tensor): (N, T, D/2)
        """

        x_r, x_g = hidden_states.chunk(2, dim=-1)

        x_g = self.norm(x_g)  # (N, T, D/2)
        x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2)  # (N, T, D/2)
        if self.linear is not None:
            x_g = self.linear(x_g)

        x_g = self.act(x_g)
        hidden_states = x_r * x_g  # (N, T, D/2)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class ConvolutionalGatingMLP(torch.nn.Module):
    """Convolutional Gating MLP (cgMLP)."""

    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__()
        self.channel_proj1 = torch.nn.Sequential(
            torch.nn.Linear(config.hidden_size, config.intermediate_size), torch.nn.GELU()
        )
        self.csgu = ConvolutionalSpatialGatingUnit(config)
        self.channel_proj2 = torch.nn.Linear(config.intermediate_size // 2, config.hidden_size)

    def forward(self, hidden_states: torch.FloatTensor):
        hidden_states = self.channel_proj1(hidden_states)  # hidden_size -> intermediate_size
        hidden_states = self.csgu(hidden_states)  # intermediate_size -> intermediate_size/2
        hidden_states = self.channel_proj2(hidden_states)  # intermediate_size/2 -> hidden_size
        return hidden_states


class Wav2Vec2EBranchformerEncoderLayer(nn.Module):
    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__()
        embed_dim = config.hidden_size
        dropout = config.attention_dropout

        # Feed-forward 1
        if config.use_macaron_ff:
            self.ff1 = nn.Sequential(nn.LayerNorm(embed_dim), Wav2Vec2EBranchformerFeedForward(config))

        # Self-Attention
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
        self.self_attn_dropout = torch.nn.Dropout(dropout)
        self.self_attn = Wav2Vec2EBranchformerSelfAttention(config)

        # cgMLP
        self.cgMLP = ConvolutionalGatingMLP(config)
        self.cgMLP_layer_norm = nn.LayerNorm(config.hidden_size)
        self.cgMLP_dropout = torch.nn.Dropout(dropout)

        # Merge
        self.final_dropout = torch.nn.Dropout(dropout)
        self.merge_proj = torch.nn.Linear(embed_dim + embed_dim, embed_dim)
        self.depthwise_conv_fusion = torch.nn.Conv1d(
            embed_dim + embed_dim,
            embed_dim + embed_dim,
            kernel_size=config.merge_conv_kernel,
            stride=1,
            padding=(config.merge_conv_kernel - 1) // 2,
            groups=embed_dim + embed_dim,
            bias=True,
        )
        self.final_layer_norm = nn.LayerNorm(embed_dim)

        # Feed-forward 2
        if config.use_macaron_ff:
            self.ff2 = nn.Sequential(nn.LayerNorm(embed_dim), Wav2Vec2EBranchformerFeedForward(config))

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor] = None,
        relative_position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        # 1. Optional ff1
        if self.ff1:
            residual = hidden_states
            hidden_states = residual + 0.5 * self.ff1(hidden_states)

        # 2. Split input to three branches
        residual = hidden_states
        global_branch = hidden_states
        local_branch = hidden_states

        # 3. Self-Attention branch
        global_branch = self.self_attn_layer_norm(global_branch)
        global_branch, attn_weights = self.self_attn(
            hidden_states=global_branch,
            attention_mask=attention_mask,
            relative_position_embeddings=relative_position_embeddings,
            output_attentions=output_attentions,
        )
        global_branch = self.self_attn_dropout(global_branch)

        # 4. cgMLP Branch
        local_branch = self.cgMLP_layer_norm(local_branch)
        local_branch = self.cgMLP(local_branch)

        # 5. Merge operator
        # a, concat
        hidden_states = torch.cat([global_branch, local_branch], dim=-1)
        merge_residual = hidden_states
        # b, depth-wise conv mixing
        hidden_states = merge_residual + self.depthwise_conv_fusion(hidden_states.transpose(1, 2)).transpose(1, 2)
        # c, project back to original size and final dropout
        hidden_states = self.final_dropout(self.merge_proj(hidden_states))

        # 6. Add residual
        hidden_states = residual + hidden_states

        # 7. Optional ff2
        if self.ff2:
            residual = hidden_states
            hidden_states = residual + 0.5 * self.ff2(hidden_states)

        # 8. Final layer norm
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states, attn_weights


class Wav2Vec2EBranchformerEncoder(Wav2Vec2ConformerEncoder):
    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Wav2Vec2EBranchformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.pos_conv_embed = None


class Wav2Vec2EBranchformerModel(Wav2Vec2ConformerModel):
    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__(config)
        self.encoder = Wav2Vec2EBranchformerEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()


class Wav2Vec2EBranchformerForPreTraining(Wav2Vec2ForPreTraining):
    config_class = Wav2Vec2EBranchformerConfig
    base_model_prefix = "wav2vec2"

    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2EBranchformerModel(config)
        self.post_init()


class Wav2Vec2EBranchformerForCTC(Wav2Vec2ForCTC):
    config_class = Wav2Vec2EBranchformerConfig
    base_model_prefix = "wav2vec2"

    def __init__(self, config: Wav2Vec2EBranchformerConfig):
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2EBranchformerModel(config)
        self.post_init()
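To make the cgMLP branch's dimensions concrete: channel_proj1 expands hidden_size to intermediate_size, the CSGU halves that by splitting into a residual half and a gated half, and channel_proj2 projects back. A plain-torch shape check with the sizes from config.json (hidden 512, intermediate 2048); the CSGU's LayerNorm and depthwise Conv1d on the gate half are omitted here for brevity, this only illustrates the split-and-gate shapes.

import torch

hidden_size, intermediate_size, T = 512, 2048, 100
x = torch.randn(1, T, hidden_size)

up = torch.nn.Sequential(torch.nn.Linear(hidden_size, intermediate_size), torch.nn.GELU())
h = up(x)                      # (1, T, 2048)
x_r, x_g = h.chunk(2, dim=-1)  # each (1, T, 1024): residual half and gated half
gated = x_r * x_g              # CSGU gating output
down = torch.nn.Linear(intermediate_size // 2, hidden_size)
print(down(gated).shape)       # torch.Size([1, 100, 512])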
embeddings.py ADDED
@@ -0,0 +1,86 @@
import torch
from torch import nn


class AdaptiveEmbedding(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False):
        super().__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        self.cutoffs = cutoffs + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        self.emb_scale = d_proj**0.5

        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val == 1:
            self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0))
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = d_embed // (div_val**i)
                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))

    def forward(self, inp):
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = nn.functional.linear(embed, self.emb_projs[0])
        else:
            param = next(self.parameters())
            inp_flat = inp.view(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                indices_i = mask_i.nonzero().squeeze()

                if indices_i.numel() == 0:
                    continue

                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = nn.functional.linear(emb_i, self.emb_projs[i])

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed_shape = inp.size() + (self.d_proj,)
            embed = emb_flat.view(embed_shape)

        embed.mul_(self.emb_scale)

        return embed


class PositionalEmbeddingAux(nn.Module):
    def __init__(self, demb):
        super().__init__()

        self.demb = demb

        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.outer(pos_seq, self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)

        if bsz is not None:
            return pos_emb[:, None, :].expand(-1, bsz, -1)
        else:
            return pos_emb[:, None, :]


class PositionalEmbedding(PositionalEmbeddingAux):
    def forward(self, pos_seq, bsz=None):
        return super().forward(pos_seq.squeeze(0), bsz=bsz).squeeze(1)
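PositionalEmbeddingAux builds standard sinusoidal embeddings from an inverse-frequency table. A minimal reproduction of that math for a handful of positions, assuming demb is even (it is 512 for this decoder):

import torch

demb = 512
inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))  # (demb/2,)
pos_seq = torch.arange(4.0)                                      # positions 0..3
sinusoid_inp = torch.outer(pos_seq, inv_freq)                    # (4, demb/2)
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
print(pos_emb.shape)  # torch.Size([4, 512])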
extractors.py ADDED
@@ -0,0 +1,32 @@
import torch
from torch import nn
from transformers.activations import ACT2FN


class Conv2dFeatureExtractor(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.conv = torch.nn.Sequential(
            *[
                nn.Sequential(
                    nn.Conv2d(
                        conv_in,
                        out_channels=conv_out,
                        kernel_size=(conv_kernel, conv_kernel),
                        stride=(conv_stride, conv_stride),
                    ),
                    ACT2FN[config.feat_extract_activation],
                )
                for conv_in, conv_out, conv_kernel, conv_stride in zip(
                    [1, *config.conv_dim], config.conv_dim, config.conv_kernel, config.conv_stride
                )
            ],
        )

        linear_in_dim = config.conv_dim[-1] * (((config.second_dim_input_size - 1) // 2 - 1) // 2)
        self.out = torch.nn.Linear(linear_in_dim, config.hidden_size, bias=True)

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        hidden_states = self.conv(input_values[:, None, ...])
        hidden_states = self.out(hidden_states.transpose(1, 2).flatten(2, 3))
        return hidden_states.transpose(1, 2)
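The linear_in_dim arithmetic matches two unpadded kernel-3, stride-2 Conv2d layers over the 80 mel bins from config.json: 80 -> 39 -> 19, so the linear layer sees 512 * 19 = 9728 features. A quick shape sketch with plain torch (example frame count of 200 is arbitrary); the module above additionally transposes the result to (batch, hidden_size, frames) at the end.

import torch

feats = torch.randn(1, 200, 80)  # (batch, frames, mel bins)
conv = torch.nn.Sequential(
    torch.nn.Conv2d(1, 512, kernel_size=3, stride=2), torch.nn.GELU(),
    torch.nn.Conv2d(512, 512, kernel_size=3, stride=2), torch.nn.GELU(),
)
h = conv(feats[:, None, ...])                      # (1, 512, 49, 19)
linear_in_dim = 512 * (((80 - 1) // 2 - 1) // 2)   # 512 * 19 = 9728
out = torch.nn.Linear(linear_in_dim, 512)(h.transpose(1, 2).flatten(2, 3))
print(out.shape)                                   # torch.Size([1, 49, 512])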
generation_config.json ADDED
@@ -0,0 +1,8 @@
{
  "bos_token_id": 0,
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "max_length": 512,
  "pad_token_id": 3,
  "transformers_version": "4.31.0"
}
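A hedged decoding sketch combining these generation defaults with the CTC parameters defined in ctc_scorer.py. It assumes the files from this commit are importable locally, that `model` is the loaded JointCTCAttentionEncoderDecoder, and that `input_features` is a (batch, frames, 80) log-Mel tensor; num_beams=4 is an illustrative choice, not part of this commit.

from ctc_scorer import GenerationConfigWithCTC  # shipped in this commit

gen_config = GenerationConfigWithCTC(
    ctc_weight=0.3,          # joint CTC/attention interpolation weight
    ctc_margin=0,            # no CTC windowing
    num_beams=4,             # illustrative beam size
    bos_token_id=0,
    decoder_start_token_id=0,
    eos_token_id=1,
    pad_token_id=3,
    max_length=512,
)
ids = model.generate(input_features, generation_config=gen_config)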
modeling_reguler.py ADDED
@@ -0,0 +1,484 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import CrossEntropyLoss
7
+ from transformers import (
8
+ AutoConfig,
9
+ AutoModelForCausalLM,
10
+ AutoModelForSpeechSeq2Seq,
11
+ GenerationConfig,
12
+ PretrainedConfig,
13
+ PreTrainedModel,
14
+ SpeechEncoderDecoderConfig,
15
+ SpeechEncoderDecoderModel,
16
+ StoppingCriteriaList,
17
+ )
18
+ from transformers.generation.logits_process import LogitsProcessorList
19
+ from transformers.generation.utils import GenerateOutput
20
+ from transformers.modeling_outputs import CausalLMOutput, Seq2SeqLMOutput
21
+ from transformers.models.speech_encoder_decoder.modeling_speech_encoder_decoder import (
22
+ shift_tokens_right,
23
+ )
24
+ from transformers.utils import logging
25
+
26
+ from .auto_wrappers import CustomAutoModelForCTC
27
+ from .configuration_reguler import JointCTCAttentionEncoderDecoderConfig
28
+ from .ctc_scorer import (
29
+ CTCRescorerLogitsProcessor,
30
+ GenerationConfigWithCTC,
31
+ LogSoftmaxProcessor,
32
+ )
33
+ from .embeddings import AdaptiveEmbedding, PositionalEmbedding
34
+ from .multi_head_gpt2 import GPT2LMMultiHeadModel
35
+
36
+ logger = logging.get_logger("transformers")
37
+
38
+
39
+ def wav2vec2_forward_hidden_return_hook(_: PreTrainedModel, __: Any, kwargs):
40
+ kwargs["output_hidden_states"] = True
41
+
42
+
43
+ @dataclass
44
+ class Seq2SeqLMOutputLosses(Seq2SeqLMOutput):
45
+ enc_loss: Optional[torch.FloatTensor] = None
46
+ dec_loss: Optional[torch.FloatTensor] = None
47
+ encoder_logits: Optional[torch.FloatTensor] = None
48
+
49
+
50
+ def wav2vec2_for_ctc_forward_hook(model: CustomAutoModelForCTC, input: Any, output: CausalLMOutput):
51
+ if "hidden_states" in output:
52
+ output.last_hidden_state = output.hidden_states[-1]
53
+
54
+
55
+ class JointCTCAttentionEncoderDecoder(SpeechEncoderDecoderModel):
56
+ """Custom model for CTC+Attention loss based on the ESPNet architecture"""
57
+
58
+ config_class = JointCTCAttentionEncoderDecoderConfig
59
+ base_model_prefix = "joint_aed_ctc_speech-encoder-decoder"
60
+
61
+ def __init__(
62
+ self,
63
+ config: Optional[PretrainedConfig] = None,
64
+ encoder: Optional[PreTrainedModel] = None,
65
+ decoder: Optional[PreTrainedModel] = None,
66
+ ):
67
+ if config is None and (encoder is None or decoder is None):
68
+ raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
69
+ if config is None:
70
+ config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
71
+ else:
72
+ if not isinstance(config, self.config_class):
73
+ raise ValueError(f"Config: {config} has to be of type {self.config_class}")
74
+
75
+ if config.decoder.cross_attention_hidden_size is not None:
76
+ if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
77
+ raise ValueError(
78
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
79
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
80
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
81
+ " `config.encoder.hidden_size`."
82
+ )
83
+
84
+ # initialize with config
85
+ # make sure input & output embeddings is not tied
86
+ config.tie_word_embeddings = False
87
+ super(SpeechEncoderDecoderModel, self).__init__(config)
88
+
89
+ if encoder is None:
90
+ encoder = CustomAutoModelForCTC.from_config(config.encoder)
91
+ encoder.register_forward_hook(wav2vec2_for_ctc_forward_hook)
92
+ encoder.register_forward_pre_hook(wav2vec2_forward_hidden_return_hook, with_kwargs=True)
93
+ if decoder is None:
94
+ decoder = AutoModelForCausalLM.from_config(config.decoder)
95
+
96
+ self.encoder = encoder
97
+ self.decoder = decoder
98
+
99
+ if self.encoder.config.to_dict() != self.config.encoder.to_dict():
100
+ logger.warning(
101
+ f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
102
+ f" {self.config.encoder}"
103
+ )
104
+ if self.decoder.config.to_dict() != self.config.decoder.to_dict():
105
+ logger.warning(
106
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
107
+ f" {self.config.decoder}"
108
+ )
109
+
110
+ # make sure that the individual model's config refers to the shared config
111
+ # so that the updates to the config will be synced
112
+ self.encoder.config = self.config.encoder
113
+ self.decoder.config = self.config.decoder
114
+
115
+ # get encoder output hidden size
116
+ self.encoder_output_dim = getattr(config.encoder, "output_hidden_size", config.encoder.hidden_size)
117
+ if (
118
+ self.encoder_output_dim != self.decoder.config.hidden_size
119
+ and self.decoder.config.cross_attention_hidden_size is None
120
+ ):
121
+ # encoder outputs might need to be projected to different dimension for decoder
122
+ self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
123
+
124
+ if self.encoder.get_output_embeddings() is not None:
125
+ raise ValueError(
126
+ f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
127
+ )
128
+ self.enc_loss_weight = config.ctc_weight
129
+ self.dec_loss_weight = 1 - config.ctc_weight
130
+ self.lsm_factor = config.lsm_factor
131
+
132
+ if config.shared_lm_head:
133
+ self.encoder.lm_head.weight = self.decoder.lm_head.weight
134
+
135
+ if (hasattr(config, "decoder_pos_emb_fixed") and config.decoder_pos_emb_fixed) or (
136
+ hasattr(config.decoder, "pos_emb_fixed") and config.decoder.pos_emb_fixed
137
+ ):
138
+ self.decoder.transformer.wte = AdaptiveEmbedding(
139
+ n_token=config.decoder.vocab_size,
140
+ d_embed=config.decoder.hidden_size,
141
+ d_proj=config.decoder.hidden_size,
142
+ cutoffs=[],
143
+ )
144
+ self.decoder.transformer.wpe = PositionalEmbedding(demb=config.decoder.hidden_size)
145
+ self.decoder.post_init()
146
+
147
+ self.encoder_logits = None
148
+ self.encoder_output_lens = None
149
+
150
+ @classmethod
151
+ def from_encoder_decoder_pretrained(
152
+ cls,
153
+ encoder_pretrained_model_name_or_path: str = None,
154
+ decoder_pretrained_model_name_or_path: str = None,
155
+ *model_args,
156
+ **kwargs,
157
+ ) -> PreTrainedModel:
158
+ kwargs_encoder = {
159
+ argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
160
+ }
161
+
162
+ kwargs_decoder = {
163
+ argument[len("decoder_") :]: value
164
+ for argument, value in kwargs.items()
165
+ if argument.startswith("decoder_") and argument != "decoder_start_token_id"
166
+ }
167
+
168
+ # remove encoder, decoder kwargs from kwargs
169
+ for key in kwargs_encoder.keys():
170
+ del kwargs["encoder_" + key]
171
+ for key in kwargs_decoder.keys():
172
+ del kwargs["decoder_" + key]
173
+
174
+ # Load and initialize the encoder and decoder
175
+ # The distinction between encoder and decoder at the model level is made
176
+ # by the value of the flag `is_decoder` that we need to set correctly.
177
+ encoder = kwargs_encoder.pop("model", None)
178
+ if encoder is None:
179
+ if encoder_pretrained_model_name_or_path is None:
180
+ raise ValueError(
181
+ "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
182
+ "to be defined."
183
+ )
184
+
185
+ if "config" not in kwargs_encoder:
186
+ encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
187
+ encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
188
+ )
189
+
190
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
191
+ logger.info(
192
+ f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
193
+ "from a decoder model. Cross-attention and casual mask are disabled."
194
+ )
195
+ encoder_config.is_decoder = False
196
+ encoder_config.add_cross_attention = False
197
+
198
+ kwargs_encoder["config"] = encoder_config
199
+
200
+ encoder = CustomAutoModelForCTC.from_pretrained(
201
+ encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
202
+ )
203
+ encoder.register_forward_hook(wav2vec2_for_ctc_forward_hook)
204
+
205
+ decoder = kwargs_decoder.pop("model", None)
206
+ if decoder is None:
207
+ if decoder_pretrained_model_name_or_path is None:
208
+ raise ValueError(
209
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
210
+ "to be defined."
211
+ )
212
+
213
+ if "config" not in kwargs_decoder:
214
+ decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
215
+ decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
216
+ )
217
+
218
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
219
+ logger.info(
220
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
221
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
222
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
223
+ )
224
+ decoder_config.is_decoder = True
225
+ decoder_config.add_cross_attention = True
226
+
227
+ kwargs_decoder["config"] = decoder_config
228
+
229
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
230
+ logger.warning(
231
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
232
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
233
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
234
+ "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
235
+ "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
236
+ )
237
+
238
+ decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
239
+
240
+ # instantiate config with corresponding kwargs
241
+ config = JointCTCAttentionEncoderDecoderConfig.from_encoder_decoder_configs(
242
+ encoder.config, decoder.config, **kwargs
243
+ )
244
+
245
+ # make sure input & output embeddings is not tied
246
+         config.tie_word_embeddings = False
+         return cls(encoder=encoder, decoder=decoder, config=config)
+
+     def forward(
+         self,
+         inputs: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         decoder_attention_mask: Optional[torch.BoolTensor] = None,
+         encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         input_values: Optional[torch.FloatTensor] = None,
+         input_features: Optional[torch.FloatTensor] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
+     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutputLosses]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
+
+         kwargs_decoder = {
+             argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+         }
+
+         if encoder_outputs is None:
+             if inputs is None:
+                 if input_values is not None and input_features is not None:
+                     raise ValueError("You cannot specify both input_values and input_features at the same time")
+                 elif input_values is not None:
+                     inputs = input_values
+                 elif input_features is not None:
+                     inputs = input_features
+                 else:
+                     raise ValueError("You have to specify either input_values or input_features")
+
+             encoder_outputs = self.encoder(
+                 inputs,
+                 attention_mask=attention_mask,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+                 labels=labels,
+                 **kwargs_encoder,
+             )
+         elif isinstance(encoder_outputs, tuple):
+             encoder_outputs = CausalLMOutput(*encoder_outputs)
+
+         encoder_hidden_states = encoder_outputs.last_hidden_state
+
+         # optionally project encoder_hidden_states
+         if (
+             self.encoder_output_dim != self.decoder.config.hidden_size
+             and self.decoder.config.cross_attention_hidden_size is None
+         ):
+             encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+         # compute correct encoder attention mask
+         if attention_mask is not None:
+             encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
+                 encoder_hidden_states.shape[1], attention_mask
+             )
+         else:
+             encoder_attention_mask = None
+
+         if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+             decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+         # Decode
+         decoder_outputs = self.decoder(
+             input_ids=decoder_input_ids,
+             attention_mask=decoder_attention_mask,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             inputs_embeds=decoder_inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=True
+             if hasattr(self.decoder, "head_weights") and len(self.decoder.head_weights) > 1
+             else output_hidden_states,
+             use_cache=use_cache,
+             past_key_values=past_key_values,
+             return_dict=return_dict,
+             **kwargs_decoder,
+         )
+
+         # Compute loss independent from decoder (as some shift the logits inside them)
+         loss = enc_loss = dec_loss = None
+
+         if labels is not None:
+             loss_fct = CrossEntropyLoss(label_smoothing=self.lsm_factor)
+             enc_loss = encoder_outputs.loss if return_dict else encoder_outputs[0]
+             if isinstance(self.decoder, GPT2LMMultiHeadModel) and len(self.decoder.head_weights) > 1:
+                 dec_loss = torch.zeros_like(enc_loss)
+                 lm_logits_per_layer = []
+                 for index, lm_head, lm_weight in zip(
+                     [*self.decoder.head_locations, -1],
+                     [*self.decoder.additional_lm_heads, self.decoder.lm_head],
+                     self.decoder.head_weights,
+                 ):
+                     lm_logits = lm_head(decoder_outputs.hidden_states[index])
+                     dec_loss += lm_weight * loss_fct(
+                         lm_logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1)
+                     )
+                     lm_logits_per_layer.append(lm_logits)
+                 if self.decoder.config.average_logits:
+                     decoder_outputs.logits = torch.matmul(
+                         torch.stack(lm_logits_per_layer).T,
+                         torch.tensor(self.decoder.head_weights, device=lm_logits_per_layer[-1].device),
+                     ).T
+
+             else:
+                 dec_logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
+                 dec_loss = loss_fct(dec_logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
+             loss = self.enc_loss_weight * enc_loss + self.dec_loss_weight * dec_loss
+
+         if not return_dict:
+             if loss is not None:
+                 return (loss,) + decoder_outputs + encoder_outputs
+             else:
+                 return decoder_outputs + encoder_outputs
+
+         return Seq2SeqLMOutputLosses(
+             loss=loss,
+             enc_loss=enc_loss,
+             dec_loss=dec_loss,
+             logits=decoder_outputs.logits,
+             past_key_values=decoder_outputs.past_key_values,
+             decoder_hidden_states=decoder_outputs.hidden_states,
+             decoder_attentions=decoder_outputs.attentions,
+             cross_attentions=decoder_outputs.cross_attentions,
+             encoder_last_hidden_state=encoder_hidden_states,
+             encoder_hidden_states=encoder_outputs.hidden_states,
+             encoder_attentions=encoder_outputs.attentions,
+             encoder_logits=encoder_outputs.logits,
+         )
+
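+     # The generation-related overrides below hook the encoder's CTC head into decoding:
+     # `_prepare_encoder_decoder_kwargs_for_generation` caches the encoder's CTC logits and output
+     # lengths, `_get_logits_processor` builds a `CTCRescorerLogitsProcessor` from them whenever
+     # `generation_config.ctc_weight > 0`, and `generate` clears the cached tensors once decoding ends.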
+     def _get_logits_processor(
+         self,
+         generation_config: GenerationConfigWithCTC,
+         input_ids_seq_length: int,
+         encoder_input_ids: torch.LongTensor,
+         prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
+         logits_processor: Optional[LogitsProcessorList],
+     ) -> LogitsProcessorList:
+         processors = super()._get_logits_processor(
+             generation_config, input_ids_seq_length, encoder_input_ids, prefix_allowed_tokens_fn, logits_processor
+         )
+         if generation_config.ctc_weight > 0:
+             if generation_config.num_beams <= 1:
+                 processors.append(LogSoftmaxProcessor())
+             self.ctc_rescorer = CTCRescorerLogitsProcessor(
+                 self.encoder_logits,
+                 self.encoder_output_lens,
+                 self.generation_config.pad_token_id,
+                 self.generation_config.eos_token_id,
+                 self.generation_config.ctc_margin,
+                 self.generation_config.ctc_weight,
+                 self.generation_config.num_beams,
+             )
+             processors.append(self.ctc_rescorer)
+         return processors
+
+     def _prepare_encoder_decoder_kwargs_for_generation(
+         self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
+     ) -> Dict[str, Any]:
+         self.encoder_output_lens = self.encoder._get_feat_extract_output_lengths(
+             model_kwargs["attention_mask"].sum(dim=1)
+         )
+         model_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation(
+             inputs_tensor, model_kwargs, model_input_name
+         )
+         self.encoder_logits = model_kwargs["encoder_outputs"].logits
+         return model_kwargs
+
+     @staticmethod
+     def _expand_inputs_for_generation(
+         expand_size: int = 1,
+         is_encoder_decoder: bool = False,
+         input_ids: Optional[torch.LongTensor] = None,
+         **model_kwargs,
+     ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
+         """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
+
+         def _expand_dict_for_generation(dict_to_expand):
+             for key in dict_to_expand:
+                 if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor) and key != "loss":
+                     dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+             return dict_to_expand
+
+         if input_ids is not None:
+             input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+         model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+         if is_encoder_decoder:
+             if model_kwargs.get("encoder_outputs") is None:
+                 raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+             model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+             model_kwargs["encoder_outputs"].last_hidden_state = model_kwargs[
+                 "encoder_outputs"
+             ].last_hidden_state.repeat_interleave(expand_size, dim=0)
+
+         return input_ids, model_kwargs
+
+     @torch.no_grad()
+     def generate(
+         self,
+         inputs: Optional[torch.Tensor] = None,
+         generation_config: Optional[GenerationConfig] = None,
+         logits_processor: Optional[LogitsProcessorList] = None,
+         stopping_criteria: Optional[StoppingCriteriaList] = None,
+         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+         synced_gpus: Optional[bool] = None,
+         assistant_model: Optional["PreTrainedModel"] = None,
+         streamer: Optional["BaseStreamer"] = None,
+         **kwargs,
+     ) -> Union[GenerateOutput, torch.LongTensor]:
+         output = super().generate(
+             inputs,
+             generation_config,
+             logits_processor,
+             stopping_criteria,
+             prefix_allowed_tokens_fn,
+             synced_gpus,
+             assistant_model,
+             streamer,
+             **kwargs,
+         )
+         self.encoder_logits = None
+         self.encoder_output_lens = None
+         return output
+
+
+ AutoConfig.register("joint_aed_ctc_speech-encoder-decoder", JointCTCAttentionEncoderDecoderConfig)
+ AutoModelForSpeechSeq2Seq.register(JointCTCAttentionEncoderDecoderConfig, JointCTCAttentionEncoderDecoder)
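For context, the registrations above make the model loadable through the Auto classes with `trust_remote_code=True`. Below is a minimal usage sketch; the repository id is a placeholder, and the presence of a bundled tokenizer/feature extractor as well as the `ctc_weight`/`ctc_margin` fields in the checkpoint's generation config are assumptions for illustration, not something this commit guarantees.

import torch
from transformers import AutoFeatureExtractor, AutoModelForSpeechSeq2Seq, AutoTokenizer

repo = "user/joint-aed-ctc-asr"  # hypothetical repository id, not taken from this commit
feature_extractor = AutoFeatureExtractor.from_pretrained(repo)
tokenizer = AutoTokenizer.from_pretrained(repo)
# trust_remote_code pulls in the custom classes registered above.
model = AutoModelForSpeechSeq2Seq.from_pretrained(repo, trust_remote_code=True)

speech = torch.randn(16000).numpy()  # ~1 s of dummy 16 kHz audio
inputs = feature_extractor(speech, sampling_rate=16000, return_attention_mask=True, return_tensors="pt")
# The checkpoint's generation config is assumed to carry ctc_weight/ctc_margin for the CTC rescorer.
predicted_ids = model.generate(**inputs, num_beams=4, max_length=128)
print(tokenizer.batch_decode(predicted_ids, skip_special_tokens=True))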
multi_head_gpt2.py ADDED
@@ -0,0 +1,160 @@
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+
+
+ class GPT2MultiHeadConfig(GPT2Config):
+     model_type = "gpt2-multi-head"
+
+     def __init__(
+         self,
+         head_locations=None,
+         head_weights=None,
+         tie_additional_weights=False,
+         average_logits=False,
+         *args,
+         **kwargs,
+     ):
+         super().__init__(*args, **kwargs)
+         self.head_locations = head_locations
+         self.head_weights = head_weights
+         self.tie_additional_weights = tie_additional_weights
+         self.average_logits = average_logits
+
+
+ class GPT2LMMultiHeadModel(GPT2LMHeadModel):
+     config_class = GPT2MultiHeadConfig
+
+     def __init__(self, config: GPT2MultiHeadConfig):
+         super().__init__(config)
+         if config.head_locations is not None:
+             if not len(config.head_locations) + 1 == len(config.head_weights):
+                 raise ValueError("The number of head locations should be equal to the number of head weights minus 1")
+             self.head_locations = config.head_locations
+             self.additional_lm_heads = nn.ModuleList(
+                 [nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in config.head_locations]
+             )
+             self.head_weights = config.head_weights
+         else:
+             self.head_locations = []
+             self.additional_lm_heads = nn.ModuleList([])
+             self.head_weights = [1.0]
+         self.post_init()
+
+     def tie_weights(self):
+         """
+         Tie the weights between the input embeddings and the output embeddings.
+
+         If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
+         weights instead.
+         """
+         super().tie_weights()
+         if hasattr(self, "additional_lm_heads") and getattr(self.config, "tie_additional_weights", False):
+             input_embeddings = self.get_input_embeddings()
+             for classifier in self.additional_lm_heads:
+                 if self.config.torchscript:
+                     classifier.weight = nn.Parameter(input_embeddings.weight.clone())
+                 else:
+                     classifier.weight = input_embeddings.weight
+
+                 if getattr(classifier, "bias", None) is not None:
+                     classifier.bias.data = nn.functional.pad(
+                         classifier.bias.data,
+                         (
+                             0,
+                             classifier.weight.shape[0] - classifier.bias.shape[0],
+                         ),
+                         "constant",
+                         0,
+                     )
+                 if hasattr(classifier, "out_features") and hasattr(input_embeddings, "num_embeddings"):
+                     classifier.out_features = input_embeddings.num_embeddings
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         transformer_outputs = self.transformer(
+             input_ids,
+             past_key_values=past_key_values,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=True,
+             return_dict=return_dict,
+         )
+         hidden_states = transformer_outputs[2]
+
+         # Set device for model parallelism
+         if self.model_parallel:
+             torch.cuda.set_device(self.transformer.first_device)
+             hidden_states = hidden_states.to(self.lm_head.weight.device)
+
+         lm_logits = self.lm_head(hidden_states[-1])
+         loss = None
+         if labels is not None:
+             loss = torch.tensor(0.0, device=hidden_states[-1].device)
+             lm_logits = []
+             loss_fct = CrossEntropyLoss()
+
+             for index, lm_head, lm_weight in zip(
+                 [*self.head_locations, -1],
+                 [*self.additional_lm_heads, self.lm_head],
+                 self.head_weights,
+             ):
+                 lm_logits.append(lm_head(hidden_states[index]))
+                 # Shift so that tokens < n predict n
+                 shift_logits = lm_logits[-1][..., :-1, :].contiguous()
+                 shift_labels = labels[..., 1:].contiguous()
+                 # Flatten the tokens
+                 loss += lm_weight * loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+             if self.config.average_logits:
+                 lm_logits = (torch.vstack(lm_logits) * torch.tensor(self.head_weights)).mean(dim=0)
+             else:
+                 lm_logits = lm_logits[-1]
+         if not return_dict:
+             output = (lm_logits,) + transformer_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return CausalLMOutputWithCrossAttentions(
+             loss=loss,
+             logits=lm_logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+             cross_attentions=transformer_outputs.cross_attentions,
+         )
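The decoder above attaches auxiliary LM heads to intermediate layers and sums their shifted cross-entropy losses with per-head weights (the joint encoder-decoder wrapper recomputes the same weighted loss when it detects a `GPT2LMMultiHeadModel` decoder). A rough, self-contained sketch follows; the sizes, layer indices, and weights are made up for illustration, and the file is assumed to be importable as `multi_head_gpt2`.

import torch
from multi_head_gpt2 import GPT2LMMultiHeadModel, GPT2MultiHeadConfig

# Illustrative values only: auxiliary heads on layers 6 and 9 plus the final head,
# with one more weight than head_locations (checked in __init__).
config = GPT2MultiHeadConfig(
    n_layer=12,
    n_embd=256,
    n_head=4,
    vocab_size=5000,
    head_locations=[6, 9],
    head_weights=[0.2, 0.2, 0.6],
    tie_additional_weights=True,
    average_logits=False,
)
model = GPT2LMMultiHeadModel(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss, out.logits.shape)  # loss = weighted sum of per-head losses; logits come from the final head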
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:239d2cf3c581c86eff7d96eb7eb3300a07948030cc7d56cd42a4af363a66a8f6
+ size 698154846
residual_clasiffier_gpt2.py ADDED
@@ -0,0 +1,99 @@
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+
+
+ class GPT2ResidualsLMHeadConfig(GPT2Config):
+     model_type = "gpt2-residuals-head"
+
+     def __init__(self, connected_residuals=None, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.connected_residuals = connected_residuals
+
+
+ class GPT2ResidualsLMHeadModel(GPT2LMHeadModel):
+     config_class = GPT2ResidualsLMHeadConfig
+
+     def __init__(self, config: GPT2ResidualsLMHeadConfig):
+         super().__init__(config)
+         self.connected_residuals = config.connected_residuals
+         self.lm_head = nn.Linear(config.n_embd * len(self.connected_residuals), config.vocab_size, bias=False)
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         transformer_outputs = self.transformer(
+             input_ids,
+             past_key_values=past_key_values,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=True,
+             return_dict=return_dict,
+         )
+         hidden_states = transformer_outputs[2]
+
+         # Set device for model parallelism
+         if self.model_parallel:
+             torch.cuda.set_device(self.transformer.first_device)
+             hidden_states = hidden_states.to(self.lm_head.weight.device)
+
+         hidden_states = torch.concat([hidden_states[index] for index in self.connected_residuals], dim=-1)
+         lm_logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = lm_logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+         if not return_dict:
+             output = (lm_logits,) + transformer_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return CausalLMOutputWithCrossAttentions(
+             loss=loss,
+             logits=lm_logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+             cross_attentions=transformer_outputs.cross_attentions,
+         )
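Finally, `GPT2ResidualsLMHeadModel` concatenates several hidden states along the feature dimension and feeds them to a single widened LM head. A minimal sketch with illustrative toy sizes and `connected_residuals` indices (the real values live in the checkpoint's config.json); `tie_word_embeddings` is disabled because the widened head can no longer share weights with the input embeddings.

import torch
from residual_clasiffier_gpt2 import GPT2ResidualsLMHeadConfig, GPT2ResidualsLMHeadModel

# Illustrative: concatenate the embedding output (index 0) and the last hidden layer (index -1).
config = GPT2ResidualsLMHeadConfig(
    n_layer=6,
    n_embd=256,
    n_head=4,
    vocab_size=5000,
    connected_residuals=[0, -1],
    tie_word_embeddings=False,  # the (vocab_size, 2 * n_embd) head cannot be tied to the embeddings
)
model = GPT2ResidualsLMHeadModel(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss, out.logits.shape)  # logits: (2, 16, 5000)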