|
from typing import Callable, Optional, Tuple |
|
from copy import deepcopy |
|
|
|
import numpy as np |
|
|
|
import flax |
|
import flax.linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze |
|
from flax.linen.attention import dot_product_attention_weights |
|
from flax.traverse_util import flatten_dict, unflatten_dict |
|
from jax import lax |
|
|
|
from transformers import AlbertConfig |
|
from transformers.models.albert.modeling_flax_albert import FlaxAlbertOnlyMLMHead, FlaxAlbertEmbeddings, FlaxAlbertPreTrainedModel |
|
from transformers.modeling_flax_outputs import ( |
|
FlaxBaseModelOutput, |
|
FlaxBaseModelOutputWithPooling, |
|
FlaxMaskedLMOutput, |
|
FlaxMultipleChoiceModelOutput, |
|
FlaxQuestionAnsweringModelOutput, |
|
FlaxSequenceClassifierOutput, |
|
FlaxTokenClassifierOutput, |
|
) |
|
from transformers.utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging |
|
|
|
from transformers.modeling_flax_utils import ( |
|
ACT2FN, |
|
FlaxPreTrainedModel, |
|
append_call_sample_docstring, |
|
append_replace_return_docstrings, |
|
overwrite_call_docstring, |
|
) |
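

# The Custom* classes below mirror `transformers.models.albert.modeling_flax_albert`,
# with one addition: the attention layers accept an `interv_dict` that swaps intermediate
# representations between two examples of the same batch at a chosen layer.  Based on the
# code below, the expected structure is
#
#     interv_dict = {layer_id: {rep_name: [(head_id, pos, (example_a, example_b)), ...]}}
#
# where `rep_name` is one of 'lay' (layer input), 'qry', 'key', or 'val'.  `interv_type`
# is threaded through the call chain but is not branched on below; only the default
# "swap" behaviour is implemented.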
|
|
|
class CustomFlaxAlbertSelfAttention(nn.Module): |
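    """Flax ALBERT self-attention with optional representation-swap interventions.

    Identical to `FlaxAlbertSelfAttention`, except that when `layer_id` has an entry in
    `interv_dict` the layer input and the per-head query/key/value tensors can be swapped
    between two examples of the batch before attention is computed.
    """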
|
config: AlbertConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
if self.config.hidden_size % self.config.num_attention_heads != 0: |
|
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
                f"`config.num_attention_heads`: {self.config.num_attention_heads}"
            )
|
|
|
self.query = nn.Dense( |
|
self.config.hidden_size, |
|
dtype=self.dtype, |
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range), |
|
) |
|
self.key = nn.Dense( |
|
self.config.hidden_size, |
|
dtype=self.dtype, |
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range), |
|
) |
|
self.value = nn.Dense( |
|
self.config.hidden_size, |
|
dtype=self.dtype, |
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range), |
|
) |
|
self.dense = nn.Dense( |
|
self.config.hidden_size, |
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range), |
|
dtype=self.dtype, |
|
) |
|
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) |
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) |
|
|
|
def __call__( |
|
self, |
|
hidden_states, |
|
attention_mask, |
|
deterministic=True, |
|
output_attentions: bool = False, |
|
        layer_id: Optional[int] = None,
|
interv_type: str = "swap", |
|
interv_dict: dict = {}, |
|
): |
|
head_dim = self.config.hidden_size // self.config.num_attention_heads |
|
|
|
query_states = self.query(hidden_states).reshape( |
|
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) |
|
) |
|
value_states = self.value(hidden_states).reshape( |
|
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) |
|
) |
|
key_states = self.key(hidden_states).reshape( |
|
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) |
|
) |
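
        # Gather the layer input ('lay') and the per-head projections ('qry'/'key'/'val')
        # under short names so an intervention can address them uniformly.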
|
|
|
reps = { |
|
'lay': hidden_states, |
|
'qry': query_states, |
|
'key': key_states, |
|
'val': value_states, |
|
} |
|
        if layer_id in interv_dict:
            interv = interv_dict[layer_id]
            for rep_name in ['lay', 'qry', 'key', 'val']:
                if rep_name in interv:
                    # JAX arrays are immutable, so the swap is written with `.at[...].set(...)`
                    # instead of in-place assignment.  Both updates read from the untouched
                    # `reps[rep_name]`, so the exchange stays symmetric even though the two
                    # writes happen sequentially.
                    new_state = reps[rep_name]
                    for head_id, pos, swap_ids in interv[rep_name]:
                        new_state = new_state.at[swap_ids[0], pos, head_id].set(
                            reps[rep_name][swap_ids[1], pos, head_id]
                        )
                        new_state = new_state.at[swap_ids[1], pos, head_id].set(
                            reps[rep_name][swap_ids[0], pos, head_id]
                        )
                    reps[rep_name] = new_state

        hidden_states = reps['lay']
        query_states = reps['qry']
        key_states = reps['key']
        value_states = reps['val']
|
|
|
|
|
if attention_mask is not None: |
|
|
|
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) |
|
attention_bias = lax.select( |
|
attention_mask > 0, |
|
jnp.full(attention_mask.shape, 0.0).astype(self.dtype), |
|
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), |
|
) |
|
else: |
|
attention_bias = None |
|
|
|
dropout_rng = None |
|
if not deterministic and self.config.attention_probs_dropout_prob > 0.0: |
|
dropout_rng = self.make_rng("dropout") |
|
|
|
attn_weights = dot_product_attention_weights( |
|
query_states, |
|
key_states, |
|
bias=attention_bias, |
|
dropout_rng=dropout_rng, |
|
dropout_rate=self.config.attention_probs_dropout_prob, |
|
broadcast_dropout=True, |
|
deterministic=deterministic, |
|
dtype=self.dtype, |
|
precision=None, |
|
) |
|
|
|
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) |
|
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) |
|
|
|
projected_attn_output = self.dense(attn_output) |
|
projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic) |
|
layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states) |
|
outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,) |
|
return outputs |
|
|
|
class CustomFlaxAlbertLayer(nn.Module): |
|
config: AlbertConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.attention = CustomFlaxAlbertSelfAttention(self.config, dtype=self.dtype) |
|
self.ffn = nn.Dense( |
|
self.config.intermediate_size, |
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range), |
|
dtype=self.dtype, |
|
) |
|
self.activation = ACT2FN[self.config.hidden_act] |
|
self.ffn_output = nn.Dense( |
|
self.config.hidden_size, |
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range), |
|
dtype=self.dtype, |
|
) |
|
self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) |
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) |
|
|
|
def __call__( |
|
self, |
|
hidden_states, |
|
attention_mask, |
|
deterministic: bool = True, |
|
output_attentions: bool = False, |
|
        layer_id: Optional[int] = None,
|
interv_type: str = "swap", |
|
interv_dict: dict = {}, |
|
): |
|
attention_outputs = self.attention( |
|
hidden_states, |
|
attention_mask, |
|
deterministic=deterministic, |
|
output_attentions=output_attentions, |
|
layer_id=layer_id, |
|
interv_type=interv_type, |
|
interv_dict=interv_dict, |
|
) |
|
attention_output = attention_outputs[0] |
|
ffn_output = self.ffn(attention_output) |
|
ffn_output = self.activation(ffn_output) |
|
ffn_output = self.ffn_output(ffn_output) |
|
ffn_output = self.dropout(ffn_output, deterministic=deterministic) |
|
hidden_states = self.full_layer_layer_norm(ffn_output + attention_output) |
|
|
|
outputs = (hidden_states,) |
|
|
|
if output_attentions: |
|
outputs += (attention_outputs[1],) |
|
return outputs |
|
|
|
class CustomFlaxAlbertLayerCollection(nn.Module): |
|
config: AlbertConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.layers = [ |
|
CustomFlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.inner_group_num) |
|
] |
|
|
|
def __call__( |
|
self, |
|
hidden_states, |
|
attention_mask, |
|
deterministic: bool = True, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
        layer_id: Optional[int] = None,
|
interv_type: str = "swap", |
|
interv_dict: dict = {}, |
|
): |
|
layer_hidden_states = () |
|
layer_attentions = () |
|
|
|
for layer_index, albert_layer in enumerate(self.layers): |
|
layer_output = albert_layer( |
|
hidden_states, |
|
attention_mask, |
|
deterministic=deterministic, |
|
output_attentions=output_attentions, |
|
layer_id=layer_id, |
|
interv_type=interv_type, |
|
interv_dict=interv_dict, |
|
) |
|
hidden_states = layer_output[0] |
|
|
|
if output_attentions: |
|
layer_attentions = layer_attentions + (layer_output[1],) |
|
|
|
if output_hidden_states: |
|
layer_hidden_states = layer_hidden_states + (hidden_states,) |
|
|
|
outputs = (hidden_states,) |
|
if output_hidden_states: |
|
outputs = outputs + (layer_hidden_states,) |
|
if output_attentions: |
|
outputs = outputs + (layer_attentions,) |
|
return outputs |
|
|
|
class CustomFlaxAlbertLayerCollections(nn.Module): |
|
config: AlbertConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
layer_index: Optional[str] = None |
|
|
|
def setup(self): |
|
self.albert_layers = CustomFlaxAlbertLayerCollection(self.config, dtype=self.dtype) |
|
|
|
def __call__( |
|
self, |
|
hidden_states, |
|
attention_mask, |
|
deterministic: bool = True, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
        layer_id: Optional[int] = None,
|
interv_type: str = "swap", |
|
interv_dict: dict = {}, |
|
): |
|
outputs = self.albert_layers( |
|
hidden_states, |
|
attention_mask, |
|
deterministic=deterministic, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
layer_id=layer_id, |
|
interv_type=interv_type, |
|
interv_dict=interv_dict, |
|
) |
|
return outputs |
|
|
|
class CustomFlaxAlbertLayerGroups(nn.Module): |
|
config: AlbertConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.layers = [ |
|
CustomFlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype) |
|
for i in range(self.config.num_hidden_groups) |
|
] |
|
|
|
def __call__( |
|
self, |
|
hidden_states, |
|
attention_mask, |
|
deterministic: bool = True, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
interv_type: str = "swap", |
|
interv_dict: dict = {}, |
|
): |
|
all_attentions = () if output_attentions else None |
|
all_hidden_states = (hidden_states,) if output_hidden_states else None |
|
|
|
for i in range(self.config.num_hidden_layers): |
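            # `i` is the global layer index: it is forwarded as `layer_id`, the key looked up
            # in `interv_dict`, while `group_idx` below selects the parameter group whose
            # weights layer `i` shares.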
|
|
|
group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups)) |
|
layer_group_output = self.layers[group_idx]( |
|
hidden_states, |
|
attention_mask, |
|
deterministic=deterministic, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
layer_id=i, |
|
interv_type=interv_type, |
|
interv_dict=interv_dict, |
|
) |
|
hidden_states = layer_group_output[0] |
|
|
|
if output_attentions: |
|
all_attentions = all_attentions + layer_group_output[-1] |
|
|
|
if output_hidden_states: |
|
all_hidden_states = all_hidden_states + (hidden_states,) |
|
|
|
if not return_dict: |
|
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) |
|
return FlaxBaseModelOutput( |
|
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions |
|
) |
|
|
|
class CustomFlaxAlbertEncoder(nn.Module): |
|
config: AlbertConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.embedding_hidden_mapping_in = nn.Dense( |
|
self.config.hidden_size, |
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range), |
|
dtype=self.dtype, |
|
) |
|
self.albert_layer_groups = CustomFlaxAlbertLayerGroups(self.config, dtype=self.dtype) |
|
|
|
def __call__( |
|
self, |
|
hidden_states, |
|
attention_mask, |
|
deterministic: bool = True, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
interv_type: str = "swap", |
|
interv_dict: dict = {}, |
|
): |
|
hidden_states = self.embedding_hidden_mapping_in(hidden_states) |
|
return self.albert_layer_groups( |
|
hidden_states, |
|
attention_mask, |
|
deterministic=deterministic, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
interv_type=interv_type, |
|
interv_dict=interv_dict, |
|
) |
|
|
|
class CustomFlaxAlbertModule(nn.Module): |
|
config: AlbertConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
add_pooling_layer: bool = True |
|
|
|
def setup(self): |
|
self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype) |
|
self.encoder = CustomFlaxAlbertEncoder(self.config, dtype=self.dtype) |
|
if self.add_pooling_layer: |
|
self.pooler = nn.Dense( |
|
self.config.hidden_size, |
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range), |
|
dtype=self.dtype, |
|
name="pooler", |
|
) |
|
self.pooler_activation = nn.tanh |
|
else: |
|
self.pooler = None |
|
self.pooler_activation = None |
|
|
|
def __call__( |
|
self, |
|
input_ids, |
|
attention_mask, |
|
token_type_ids: Optional[np.ndarray] = None, |
|
position_ids: Optional[np.ndarray] = None, |
|
deterministic: bool = True, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
interv_type: str = "swap", |
|
interv_dict: dict = {}, |
|
): |
|
|
|
if token_type_ids is None: |
|
token_type_ids = jnp.zeros_like(input_ids) |
|
|
|
|
|
if position_ids is None: |
|
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) |
|
|
|
hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic) |
|
|
|
outputs = self.encoder( |
|
hidden_states, |
|
attention_mask, |
|
deterministic=deterministic, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
interv_type=interv_type, |
|
interv_dict=interv_dict, |
|
) |
|
hidden_states = outputs[0] |
|
if self.add_pooling_layer: |
|
pooled = self.pooler(hidden_states[:, 0]) |
|
pooled = self.pooler_activation(pooled) |
|
else: |
|
pooled = None |
|
|
|
if not return_dict: |
|
|
|
if pooled is None: |
|
return (hidden_states,) + outputs[1:] |
|
return (hidden_states, pooled) + outputs[1:] |
|
|
|
return FlaxBaseModelOutputWithPooling( |
|
last_hidden_state=hidden_states, |
|
pooler_output=pooled, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
class CustomFlaxAlbertForMaskedLMModule(nn.Module): |
|
config: AlbertConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.albert = CustomFlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) |
|
self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype) |
|
|
|
def __call__( |
|
self, |
|
input_ids, |
|
attention_mask, |
|
token_type_ids, |
|
position_ids, |
|
deterministic: bool = True, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
interv_type: str = "swap", |
|
interv_dict: dict = {}, |
|
): |
|
|
|
outputs = self.albert( |
|
input_ids, |
|
attention_mask, |
|
token_type_ids, |
|
position_ids, |
|
deterministic=deterministic, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
interv_type=interv_type, |
|
interv_dict=interv_dict, |
|
) |
|
|
|
hidden_states = outputs[0] |
|
if self.config.tie_word_embeddings: |
|
shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] |
|
else: |
|
shared_embedding = None |
|
|
|
|
|
logits = self.predictions(hidden_states, shared_embedding=shared_embedding) |
|
|
|
if not return_dict: |
|
return (logits,) + outputs[1:] |
|
|
|
return FlaxMaskedLMOutput( |
|
logits=logits, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
class CustomFlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel): |
|
module_class = CustomFlaxAlbertForMaskedLMModule |
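

if __name__ == "__main__":
    # Minimal smoke test / usage sketch with randomly initialised weights.  The config
    # sizes and the intervention indices below are illustrative, not taken from any
    # particular experiment or checkpoint.
    config = AlbertConfig(hidden_size=768, num_attention_heads=12, intermediate_size=3072)
    model = CustomFlaxAlbertForMaskedLM(config, seed=0)

    input_ids = jnp.ones((2, 8), dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    token_type_ids = jnp.zeros_like(input_ids)
    position_ids = jnp.broadcast_to(jnp.arange(input_ids.shape[-1]), input_ids.shape)

    # At layer 2, swap the query representation of head 0 at position 3
    # between examples 0 and 1 of the batch.
    interv_dict = {2: {"qry": [(0, 3, (0, 1))]}}

    outputs = model.module.apply(
        {"params": model.params},
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        interv_dict=interv_dict,
    )
    print(outputs.logits.shape)  # (2, 8, config.vocab_size)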
|
|