taka-yamakoshi committed on
Commit c2e2449
1 Parent(s): 450ed9d

remove flax

Files changed (1)
  1. custom_modeling_albert_flax.py +0 -493
custom_modeling_albert_flax.py DELETED
@@ -1,493 +0,0 @@
-from typing import Callable, Optional, Tuple
-from copy import deepcopy
-
-import numpy as np
-
-import flax
-import flax.linen as nn
-import jax
-import jax.numpy as jnp
-from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
-from flax.linen.attention import dot_product_attention_weights
-from flax.traverse_util import flatten_dict, unflatten_dict
-from jax import lax
-
-from transformers import AlbertConfig
-from transformers.models.albert.modeling_flax_albert import FlaxAlbertOnlyMLMHead, FlaxAlbertEmbeddings, FlaxAlbertPreTrainedModel
-from transformers.modeling_flax_outputs import (
-    FlaxBaseModelOutput,
-    FlaxBaseModelOutputWithPooling,
-    FlaxMaskedLMOutput,
-    FlaxMultipleChoiceModelOutput,
-    FlaxQuestionAnsweringModelOutput,
-    FlaxSequenceClassifierOutput,
-    FlaxTokenClassifierOutput,
-)
-from transformers.utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
-
-from transformers.modeling_flax_utils import (
-    ACT2FN,
-    FlaxPreTrainedModel,
-    append_call_sample_docstring,
-    append_replace_return_docstrings,
-    overwrite_call_docstring,
-)
-
-class CustomFlaxAlbertSelfAttention(nn.Module):
-    config: AlbertConfig
-    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-
-    def setup(self):
-        if self.config.hidden_size % self.config.num_attention_heads != 0:
-            raise ValueError(
-                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
-                f": {self.config.num_attention_heads}"
-            )
-
-        self.query = nn.Dense(
-            self.config.hidden_size,
-            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-        )
-        self.key = nn.Dense(
-            self.config.hidden_size,
-            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-        )
-        self.value = nn.Dense(
-            self.config.hidden_size,
-            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-        )
-        self.dense = nn.Dense(
-            self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-            dtype=self.dtype,
-        )
-        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
-        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
-
-    def __call__(
-        self,
-        hidden_states,
-        attention_mask,
-        deterministic=True,
-        output_attentions: bool = False,
-        layer_id: int = None,
-        interv_type: str = "swap",
-        interv_dict: dict = {},
-    ):
-        head_dim = self.config.hidden_size // self.config.num_attention_heads
-
-        query_states = self.query(hidden_states).reshape(
-            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
-        )
-        value_states = self.value(hidden_states).reshape(
-            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
-        )
-        key_states = self.key(hidden_states).reshape(
-            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
-        )
-
-        reps = {
-            'lay': hidden_states,
-            'qry': query_states,
-            'key': key_states,
-            'val': value_states,
-        }
-        if layer_id in interv_dict:
-            interv = interv_dict[layer_id]
-            for rep_name in ['lay', 'qry', 'key', 'val']:
-                if rep_name in interv:
-                    # JAX arrays are immutable, so apply the swap with functional `.at[...].set(...)` updates.
-                    new_state = reps[rep_name]
-                    for head_id, pos, swap_ids in interv[rep_name]:
-                        new_state = new_state.at[swap_ids[0], pos, head_id].set(reps[rep_name][swap_ids[1], pos, head_id])
-                        new_state = new_state.at[swap_ids[1], pos, head_id].set(reps[rep_name][swap_ids[0], pos, head_id])
-                    reps[rep_name] = new_state
-
-        hidden_states = reps['lay']
-        query_states = reps['qry']
-        key_states = reps['key']
-        value_states = reps['val']
-
-        # Convert the boolean attention mask to an attention bias.
-        if attention_mask is not None:
-            # attention mask in the form of attention bias
-            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
-            attention_bias = lax.select(
-                attention_mask > 0,
-                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
-            )
-        else:
-            attention_bias = None
-
-        dropout_rng = None
-        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
-            dropout_rng = self.make_rng("dropout")
-
-        attn_weights = dot_product_attention_weights(
-            query_states,
-            key_states,
-            bias=attention_bias,
-            dropout_rng=dropout_rng,
-            dropout_rate=self.config.attention_probs_dropout_prob,
-            broadcast_dropout=True,
-            deterministic=deterministic,
-            dtype=self.dtype,
-            precision=None,
-        )
-
-        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
-        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
-
-        projected_attn_output = self.dense(attn_output)
-        projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic)
-        layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states)
-        outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,)
-        return outputs
-
-class CustomFlaxAlbertLayer(nn.Module):
-    config: AlbertConfig
-    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-
-    def setup(self):
-        self.attention = CustomFlaxAlbertSelfAttention(self.config, dtype=self.dtype)
-        self.ffn = nn.Dense(
-            self.config.intermediate_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-            dtype=self.dtype,
-        )
-        self.activation = ACT2FN[self.config.hidden_act]
-        self.ffn_output = nn.Dense(
-            self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-            dtype=self.dtype,
-        )
-        self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
-        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
-
-    def __call__(
-        self,
-        hidden_states,
-        attention_mask,
-        deterministic: bool = True,
-        output_attentions: bool = False,
-        layer_id: int = None,
-        interv_type: str = "swap",
-        interv_dict: dict = {},
-    ):
-        attention_outputs = self.attention(
-            hidden_states,
-            attention_mask,
-            deterministic=deterministic,
-            output_attentions=output_attentions,
-            layer_id=layer_id,
-            interv_type=interv_type,
-            interv_dict=interv_dict,
-        )
-        attention_output = attention_outputs[0]
-        ffn_output = self.ffn(attention_output)
-        ffn_output = self.activation(ffn_output)
-        ffn_output = self.ffn_output(ffn_output)
-        ffn_output = self.dropout(ffn_output, deterministic=deterministic)
-        hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)
-
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (attention_outputs[1],)
-        return outputs
-
-class CustomFlaxAlbertLayerCollection(nn.Module):
-    config: AlbertConfig
-    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-
-    def setup(self):
-        self.layers = [
-            CustomFlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.inner_group_num)
-        ]
-
-    def __call__(
-        self,
-        hidden_states,
-        attention_mask,
-        deterministic: bool = True,
-        output_attentions: bool = False,
-        output_hidden_states: bool = False,
-        layer_id: int = None,
-        interv_type: str = "swap",
-        interv_dict: dict = {},
-    ):
-        layer_hidden_states = ()
-        layer_attentions = ()
-
-        for layer_index, albert_layer in enumerate(self.layers):
-            layer_output = albert_layer(
-                hidden_states,
-                attention_mask,
-                deterministic=deterministic,
-                output_attentions=output_attentions,
-                layer_id=layer_id,
-                interv_type=interv_type,
-                interv_dict=interv_dict,
-            )
-            hidden_states = layer_output[0]
-
-            if output_attentions:
-                layer_attentions = layer_attentions + (layer_output[1],)
-
-            if output_hidden_states:
-                layer_hidden_states = layer_hidden_states + (hidden_states,)
-
-        outputs = (hidden_states,)
-        if output_hidden_states:
-            outputs = outputs + (layer_hidden_states,)
-        if output_attentions:
-            outputs = outputs + (layer_attentions,)
-        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)
-
-class CustomFlaxAlbertLayerCollections(nn.Module):
-    config: AlbertConfig
-    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    layer_index: Optional[str] = None
-
-    def setup(self):
-        self.albert_layers = CustomFlaxAlbertLayerCollection(self.config, dtype=self.dtype)
-
-    def __call__(
-        self,
-        hidden_states,
-        attention_mask,
-        deterministic: bool = True,
-        output_attentions: bool = False,
-        output_hidden_states: bool = False,
-        layer_id: int = None,
-        interv_type: str = "swap",
-        interv_dict: dict = {},
-    ):
-        outputs = self.albert_layers(
-            hidden_states,
-            attention_mask,
-            deterministic=deterministic,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            layer_id=layer_id,
-            interv_type=interv_type,
-            interv_dict=interv_dict,
-        )
-        return outputs
-
-class CustomFlaxAlbertLayerGroups(nn.Module):
-    config: AlbertConfig
-    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-
-    def setup(self):
-        self.layers = [
-            CustomFlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype)
-            for i in range(self.config.num_hidden_groups)
-        ]
-
-    def __call__(
-        self,
-        hidden_states,
-        attention_mask,
-        deterministic: bool = True,
-        output_attentions: bool = False,
-        output_hidden_states: bool = False,
-        return_dict: bool = True,
-        interv_type: str = "swap",
-        interv_dict: dict = {},
-    ):
-        all_attentions = () if output_attentions else None
-        all_hidden_states = (hidden_states,) if output_hidden_states else None
-
-        for i in range(self.config.num_hidden_layers):
-            # Index of the hidden group
-            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
-            layer_group_output = self.layers[group_idx](
-                hidden_states,
-                attention_mask,
-                deterministic=deterministic,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                layer_id=i,
-                interv_type=interv_type,
-                interv_dict=interv_dict,
-            )
-            hidden_states = layer_group_output[0]
-
-            if output_attentions:
-                all_attentions = all_attentions + layer_group_output[-1]
-
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
-        return FlaxBaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
-        )
-
-class CustomFlaxAlbertEncoder(nn.Module):
-    config: AlbertConfig
-    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-
-    def setup(self):
-        self.embedding_hidden_mapping_in = nn.Dense(
-            self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-            dtype=self.dtype,
-        )
-        self.albert_layer_groups = CustomFlaxAlbertLayerGroups(self.config, dtype=self.dtype)
-
-    def __call__(
-        self,
-        hidden_states,
-        attention_mask,
-        deterministic: bool = True,
-        output_attentions: bool = False,
-        output_hidden_states: bool = False,
-        return_dict: bool = True,
-        interv_type: str = "swap",
-        interv_dict: dict = {},
-    ):
-        hidden_states = self.embedding_hidden_mapping_in(hidden_states)
-        return self.albert_layer_groups(
-            hidden_states,
-            attention_mask,
-            deterministic=deterministic,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            interv_type=interv_type,
-            interv_dict=interv_dict,
-        )
-
-class CustomFlaxAlbertModule(nn.Module):
-    config: AlbertConfig
-    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    add_pooling_layer: bool = True
-
-    def setup(self):
-        self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype)
-        self.encoder = CustomFlaxAlbertEncoder(self.config, dtype=self.dtype)
-        if self.add_pooling_layer:
-            self.pooler = nn.Dense(
-                self.config.hidden_size,
-                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-                dtype=self.dtype,
-                name="pooler",
-            )
-            self.pooler_activation = nn.tanh
-        else:
-            self.pooler = None
-            self.pooler_activation = None
-
-    def __call__(
-        self,
-        input_ids,
-        attention_mask,
-        token_type_ids: Optional[np.ndarray] = None,
-        position_ids: Optional[np.ndarray] = None,
-        deterministic: bool = True,
-        output_attentions: bool = False,
-        output_hidden_states: bool = False,
-        return_dict: bool = True,
-        interv_type: str = "swap",
-        interv_dict: dict = {},
-    ):
-        # make sure `token_type_ids` is correctly initialized when not passed
-        if token_type_ids is None:
-            token_type_ids = jnp.zeros_like(input_ids)
-
-        # make sure `position_ids` is correctly initialized when not passed
-        if position_ids is None:
-            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
-
-        hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic)
-
-        outputs = self.encoder(
-            hidden_states,
-            attention_mask,
-            deterministic=deterministic,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            interv_type=interv_type,
-            interv_dict=interv_dict,
-        )
-        hidden_states = outputs[0]
-        if self.add_pooling_layer:
-            pooled = self.pooler(hidden_states[:, 0])
-            pooled = self.pooler_activation(pooled)
-        else:
-            pooled = None
-
-        if not return_dict:
-            # if pooled is None, don't return it
-            if pooled is None:
-                return (hidden_states,) + outputs[1:]
-            return (hidden_states, pooled) + outputs[1:]
-
-        return FlaxBaseModelOutputWithPooling(
-            last_hidden_state=hidden_states,
-            pooler_output=pooled,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-class CustomFlaxAlbertForMaskedLMModule(nn.Module):
-    config: AlbertConfig
-    dtype: jnp.dtype = jnp.float32
-
-    def setup(self):
-        self.albert = CustomFlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
-        self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)
-
-    def __call__(
-        self,
-        input_ids,
-        attention_mask,
-        token_type_ids,
-        position_ids,
-        deterministic: bool = True,
-        output_attentions: bool = False,
-        output_hidden_states: bool = False,
-        return_dict: bool = True,
-        interv_type: str = "swap",
-        interv_dict: dict = {},
-    ):
-        # Model
-        outputs = self.albert(
-            input_ids,
-            attention_mask,
-            token_type_ids,
-            position_ids,
-            deterministic=deterministic,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            interv_type=interv_type,
-            interv_dict=interv_dict,
-        )
-
-        hidden_states = outputs[0]
-        if self.config.tie_word_embeddings:
-            shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
-        else:
-            shared_embedding = None
-
-        # Compute the prediction scores
-        logits = self.predictions(hidden_states, shared_embedding=shared_embedding)
-
-        if not return_dict:
-            return (logits,) + outputs[1:]
-
-        return FlaxMaskedLMOutput(
-            logits=logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-class CustomFlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
-    module_class = CustomFlaxAlbertForMaskedLMModule
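
For context, the deleted module family threads the `interv_type`/`interv_dict` keywords down to `CustomFlaxAlbertSelfAttention`, which swaps per-head representations between two batch items at a given layer and position. Below is a minimal usage sketch, not part of the deleted file or of this commit: it assumes Flax weights for `albert-base-v2` are available (otherwise add `from_pt=True`), and it calls the underlying Flax module via `module.apply` because the inherited `FlaxAlbertPreTrainedModel.__call__` does not forward the custom keywords. The example sentences and the layer, head, and position indices are arbitrary illustrations.

# Hypothetical sketch: swap the query vectors of attention head 0 at token position 3
# between the two sequences in the batch, at encoder layer 2, then read off MLM logits.
import jax.numpy as jnp
from transformers import AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
model = CustomFlaxAlbertForMaskedLM.from_pretrained("albert-base-v2")

batch = tokenizer(
    ["The doctor said [MASK] is busy.", "The nurse said [MASK] is busy."],
    return_tensors="np",
    padding=True,
)
input_ids = jnp.array(batch["input_ids"])
attention_mask = jnp.array(batch["attention_mask"])
token_type_ids = jnp.zeros_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(input_ids.shape[-1]), input_ids.shape)

# Format: {layer_id: {rep_name: [(head_id, position, (batch_idx_a, batch_idx_b)), ...]}}
# where rep_name is one of 'lay', 'qry', 'key', 'val' as read by CustomFlaxAlbertSelfAttention.
interv_dict = {2: {"qry": [(0, 3, (0, 1))]}}

outputs = model.module.apply(
    {"params": model.params},
    input_ids,
    attention_mask,
    token_type_ids,
    position_ids,
    interv_type="swap",
    interv_dict=interv_dict,
)
print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)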