|
import tensorflow as tf
|
|
from transformers.modeling_tf_utils import unpack_inputs
|
|
|
|
from transformers.modeling_tf_utils import TFPreTrainedModel
|
|
|
|
from .configuration_my_model import MyModelConfig
|
|
|
|
|
|
class TFMyModelPretrainedModel(TFPreTrainedModel):
    """Shared base class for the TensorFlow ``MyModel`` models.

    Its only customization of ``TFPreTrainedModel`` is binding
    ``MyModelConfig`` as the configuration class for subclasses.
    """

    # Configuration class associated with every model deriving from this base.
    config_class = MyModelConfig
|
|
|
|
|
|
class TFMyModel(TFMyModelPretrainedModel):
    """Minimal TF model: one Dense projection applied to a ``hidden`` tensor.

    The constructor reads ``n_layers`` and ``hidden_dim`` from the config.
    ``call`` feeds its ``hidden`` argument through a single
    ``tf.keras.layers.Dense`` layer and returns the result.
    """

    def __init__(self, config: MyModelConfig, *inputs, **kwargs):
        """Build the model from ``config``.

        Args:
            config: Model configuration; ``n_layers`` and ``hidden_dim`` are read.
            *inputs: Extra positional arguments, forwarded unchanged to the
                base-class constructor.
            **kwargs: Extra keyword arguments (e.g. Keras layer options such as
                ``name``), forwarded unchanged to the base-class constructor.
        """
        super().__init__(config, *inputs, **kwargs)

        self.config = config

        # Keep the config values this model depends on as plain attributes.
        self.n_layers = config.n_layers
        self.hidden_dim = config.hidden_dim

        # NOTE(review): the projection's output width is ``n_layers`` rather
        # than ``hidden_dim`` -- confirm this is intended.
        self.linear = tf.keras.layers.Dense(units=config.n_layers)

    @property
    def dummy_inputs(self):
        """Return a minimal input dict that can be fed to the model.

        Returns:
            ``{"hidden": t}`` where ``t`` is an all-zero tensor of shape
            ``(1, config.hidden_dim)``.
        """
        hidden = tf.zeros(shape=(1, self.config.hidden_dim))
        return {"hidden": hidden}

    @unpack_inputs
    def call(
        self,
        hidden,
        training: bool = False,
    ):
        """Project ``hidden`` through the Dense layer.

        Args:
            hidden: Input tensor. Its last dimension is presumably
                ``config.hidden_dim``, matching ``dummy_inputs`` -- TODO confirm.
            training: Unused here. It is declared so that a ``training``
                keyword forwarded by Keras is part of the signature seen by
                ``unpack_inputs``, rather than treated as unsupported.
                NOTE(review): confirm against the installed transformers
                version.

        Returns:
            The output of ``self.linear(hidden)``.
        """
        # No debugger trap here: the forward pass must run unattended, and
        # the projection has to be returned or the model outputs ``None``.
        return self.linear(hidden)
|
|
|
|
|