|
import jax.numpy as jnp |
|
|
|
from ...utils import logging |
|
from ..t5.modeling_flax_t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model |
|
from .configuration_mt5 import MT5Config |
|
|
|
|
|
# Module-level logger, obtained through the package's own logging helper.
logger = logging.get_logger(__name__)

# NOTE(review): this names "T5Config" while every model class in this module
# sets `config_class = MT5Config`. The constant is not read anywhere in the
# visible part of this module -- confirm whether "MT5Config" is intended
# before changing the value.
_CONFIG_FOR_DOC = "T5Config"
|
|
|
|
|
|
|
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right.

    Each row is moved one position along axis 1: the last entry of the row is dropped and
    `decoder_start_token_id` is written into position 0. Every `-100` remaining in the shifted array is then
    replaced by `pad_token_id`.

    Args:
        input_ids: Token ids, indexed as `[row, position]` (the shift is applied along axis 1).
        pad_token_id: Value substituted for each `-100` entry of the shifted array.
        decoder_start_token_id: Value written into the first position of every row.

    Returns:
        A new array with the same shape as `input_ids`; `input_ids` itself is left untouched.
    """
    shifted_input_ids = jnp.zeros_like(input_ids)
    # `.at[...].set(...)` returns an updated copy rather than mutating in place, hence the rebinding.
    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)

    # -100 is presumably a masked-label sentinel supplied by callers (TODO confirm); it is swapped for the pad id.
    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids
|
|
|
|
|
class FlaxMT5Model(FlaxT5Model):
    """
    `FlaxT5Model` subclass registered under the "mt5" model type.

    The class adds no methods of its own; it only overrides the two class attributes below so that the inherited
    implementation is paired with `MT5Config`.
    """

    # Model-type identifier for this class.
    model_type = "mt5"
    # Configuration class associated with this model.
    config_class = MT5Config
|
|
|
|
|
class FlaxMT5EncoderModel(FlaxT5EncoderModel):
    """
    `FlaxT5EncoderModel` subclass registered under the "mt5" model type.

    The class adds no methods of its own; it only overrides the two class attributes below so that the inherited
    implementation is paired with `MT5Config`.
    """

    # Model-type identifier for this class.
    model_type = "mt5"
    # Configuration class associated with this model.
    config_class = MT5Config
|
|
|
|
|
class FlaxMT5ForConditionalGeneration(FlaxT5ForConditionalGeneration):
    """
    `FlaxT5ForConditionalGeneration` subclass registered under the "mt5" model type.

    The class adds no methods of its own; it only overrides the two class attributes below so that the inherited
    implementation is paired with `MT5Config`.
    """

    # Model-type identifier for this class.
    model_type = "mt5"
    # Configuration class associated with this model.
    config_class = MT5Config
|
|