# causal-intervention-demo / custom_modeling_albert_flax.py
# Author: taka-yamakoshi
from typing import Callable, Optional, Tuple
import numpy as np
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from transformers import AlbertConfig
from transformers.models.albert.modeling_flax_albert import FlaxAlbertOnlyMLMHead, FlaxAlbertEmbeddings, FlaxAlbertPreTrainedModel
from transformers.modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPooling,
FlaxMaskedLMOutput,
FlaxMultipleChoiceModelOutput,
FlaxQuestionAnsweringModelOutput,
FlaxSequenceClassifierOutput,
FlaxTokenClassifierOutput,
)
from transformers.utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from transformers.modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_call_sample_docstring,
append_replace_return_docstrings,
overwrite_call_docstring,
)
class CustomFlaxAlbertSelfAttention(nn.Module):
config: AlbertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
                f"`config.num_attention_heads`: {self.config.num_attention_heads}"
            )
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(
self,
hidden_states,
attention_mask,
deterministic=True,
output_attentions: bool = False,
        layer_id: Optional[int] = None,
interv_type: str = "swap",
interv_dict: dict = {},
):
head_dim = self.config.hidden_size // self.config.num_attention_heads
query_states = self.query(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
value_states = self.value(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
key_states = self.key(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
reps = {
'lay': hidden_states,
'qry': query_states,
'key': key_states,
'val': value_states,
}
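        # NOTE (added comment, inferred from the code below): `interv_dict` is expected to map a
        # layer index to a per-representation spec, e.g.
        #     {layer_id: {'qry': [(head_id, pos, (batch_a, batch_b)), ...], 'key': [...], ...}}
        # where 'lay'/'qry'/'key'/'val' select the layer input, query, key, or value
        # representations, and each entry swaps the activation at `pos` (and `head_id`)
        # between the two batch elements `batch_a` and `batch_b`. `interv_type` is threaded
        # through the call stack, but only the "swap" behaviour is implemented here.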
        if layer_id in interv_dict:
            interv = interv_dict[layer_id]
            for rep_name in ['lay', 'qry', 'key', 'val']:
                if rep_name in interv:
                    old_state = reps[rep_name]
                    new_state = old_state
                    for head_id, pos, swap_ids in interv[rep_name]:
                        # JAX arrays are immutable, so apply the swap with functional
                        # `.at[...].set(...)` updates instead of in-place assignment.
                        new_state = new_state.at[swap_ids[0], pos, head_id].set(old_state[swap_ids[1], pos, head_id])
                        new_state = new_state.at[swap_ids[1], pos, head_id].set(old_state[swap_ids[0], pos, head_id])
                    reps[rep_name] = new_state
            hidden_states = reps['lay']
            query_states = reps['qry']
            key_states = reps['key']
            value_states = reps['val']
# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
)
else:
attention_bias = None
dropout_rng = None
if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
dropout_rng = self.make_rng("dropout")
attn_weights = dot_product_attention_weights(
query_states,
key_states,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.config.attention_probs_dropout_prob,
broadcast_dropout=True,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
projected_attn_output = self.dense(attn_output)
projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic)
layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states)
outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,)
return outputs
class CustomFlaxAlbertLayer(nn.Module):
config: AlbertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.attention = CustomFlaxAlbertSelfAttention(self.config, dtype=self.dtype)
self.ffn = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
self.ffn_output = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
        layer_id: Optional[int] = None,
interv_type: str = "swap",
interv_dict: dict = {},
):
attention_outputs = self.attention(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
layer_id=layer_id,
interv_type=interv_type,
interv_dict=interv_dict,
)
attention_output = attention_outputs[0]
ffn_output = self.ffn(attention_output)
ffn_output = self.activation(ffn_output)
ffn_output = self.ffn_output(ffn_output)
ffn_output = self.dropout(ffn_output, deterministic=deterministic)
hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)
outputs = (hidden_states,)
if output_attentions:
outputs += (attention_outputs[1],)
return outputs
class CustomFlaxAlbertLayerCollection(nn.Module):
config: AlbertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = [
CustomFlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.inner_group_num)
]
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
        layer_id: Optional[int] = None,
interv_type: str = "swap",
interv_dict: dict = {},
):
layer_hidden_states = ()
layer_attentions = ()
for layer_index, albert_layer in enumerate(self.layers):
layer_output = albert_layer(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
layer_id=layer_id,
interv_type=interv_type,
interv_dict=interv_dict,
)
hidden_states = layer_output[0]
if output_attentions:
layer_attentions = layer_attentions + (layer_output[1],)
if output_hidden_states:
layer_hidden_states = layer_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if output_hidden_states:
outputs = outputs + (layer_hidden_states,)
if output_attentions:
outputs = outputs + (layer_attentions,)
return outputs # last-layer hidden state, (layer hidden states), (layer attentions)
class CustomFlaxAlbertLayerCollections(nn.Module):
config: AlbertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
layer_index: Optional[str] = None
def setup(self):
self.albert_layers = CustomFlaxAlbertLayerCollection(self.config, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
        layer_id: Optional[int] = None,
interv_type: str = "swap",
interv_dict: dict = {},
):
outputs = self.albert_layers(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
layer_id=layer_id,
interv_type=interv_type,
interv_dict=interv_dict,
)
return outputs
class CustomFlaxAlbertLayerGroups(nn.Module):
config: AlbertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = [
CustomFlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_groups)
]
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
interv_type: str = "swap",
interv_dict: dict = {},
):
all_attentions = () if output_attentions else None
all_hidden_states = (hidden_states,) if output_hidden_states else None
for i in range(self.config.num_hidden_layers):
            # ALBERT shares one set of layer parameters per hidden group, so map the
            # flat layer index `i` onto the group that serves it.
            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
layer_group_output = self.layers[group_idx](
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
layer_id=i,
interv_type=interv_type,
interv_dict=interv_dict,
)
hidden_states = layer_group_output[0]
if output_attentions:
all_attentions = all_attentions + layer_group_output[-1]
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class CustomFlaxAlbertEncoder(nn.Module):
config: AlbertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.embedding_hidden_mapping_in = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.albert_layer_groups = CustomFlaxAlbertLayerGroups(self.config, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
interv_type: str = "swap",
interv_dict: dict = {},
):
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
return self.albert_layer_groups(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
interv_type=interv_type,
interv_dict=interv_dict,
)
class CustomFlaxAlbertModule(nn.Module):
config: AlbertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
add_pooling_layer: bool = True
def setup(self):
self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype)
self.encoder = CustomFlaxAlbertEncoder(self.config, dtype=self.dtype)
if self.add_pooling_layer:
self.pooler = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
name="pooler",
)
self.pooler_activation = nn.tanh
else:
self.pooler = None
self.pooler_activation = None
def __call__(
self,
input_ids,
attention_mask,
token_type_ids: Optional[np.ndarray] = None,
position_ids: Optional[np.ndarray] = None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
interv_type: str = "swap",
interv_dict: dict = {},
):
# make sure `token_type_ids` is correctly initialized when not passed
if token_type_ids is None:
token_type_ids = jnp.zeros_like(input_ids)
# make sure `position_ids` is correctly initialized when not passed
if position_ids is None:
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic)
outputs = self.encoder(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interv_type=interv_type,
interv_dict=interv_dict,
)
hidden_states = outputs[0]
if self.add_pooling_layer:
pooled = self.pooler(hidden_states[:, 0])
pooled = self.pooler_activation(pooled)
else:
pooled = None
if not return_dict:
# if pooled is None, don't return it
if pooled is None:
return (hidden_states,) + outputs[1:]
return (hidden_states, pooled) + outputs[1:]
return FlaxBaseModelOutputWithPooling(
last_hidden_state=hidden_states,
pooler_output=pooled,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class CustomFlaxAlbertForMaskedLMModule(nn.Module):
config: AlbertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.albert = CustomFlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
interv_type: str = "swap",
interv_dict: dict = {},
):
# Model
outputs = self.albert(
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interv_type=interv_type,
interv_dict=interv_dict,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
else:
shared_embedding = None
# Compute the prediction scores
logits = self.predictions(hidden_states, shared_embedding=shared_embedding)
if not return_dict:
return (logits,) + outputs[1:]
return FlaxMaskedLMOutput(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class CustomFlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
module_class = CustomFlaxAlbertForMaskedLMModule
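# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original file).
# Assumptions: the custom modules mirror the upstream FlaxAlbert module names, so a
# pretrained checkpoint such as "albert-base-v2" should load without key remapping,
# and the intervention kwargs are reached by applying the Flax module directly,
# since the stock FlaxAlbertPreTrainedModel.__call__ does not forward `interv_dict`.
if __name__ == "__main__":
    from transformers import AlbertTokenizer

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = CustomFlaxAlbertForMaskedLM.from_pretrained("albert-base-v2")

    # Two minimally different sentences batched together; the intervention swaps
    # head 0's query at position 4 of layer 3 between the two batch elements.
    batch = tokenizer(
        ["The doctor examined the [MASK].", "The lawyer examined the [MASK]."],
        return_tensors="np",
        padding=True,
    )
    interv_dict = {3: {"qry": [(0, 4, (0, 1))]}}

    position_ids = jnp.broadcast_to(
        jnp.arange(batch["input_ids"].shape[-1]), batch["input_ids"].shape
    )
    outputs = model.module.apply(
        {"params": model.params},
        batch["input_ids"],
        batch["attention_mask"],
        batch["token_type_ids"],
        position_ids,
        deterministic=True,
        interv_dict=interv_dict,
    )
    print(outputs.logits.shape)  # (2, seq_len, vocab_size)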