Spaces:
Runtime error
Runtime error
File size: 2,319 Bytes
64e21e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import tensorflow as tf
from tensorflow import keras
from keras import layers
class PositionalEmbedding(layers.Layer):
    """Add a learned embedding of each frame index to the input sequence.

    Args:
        sequence_length: Number of distinct positions the embedding table
            holds (passed as `input_dim` of the inner `Embedding`).
        output_dim: Width of each position vector (passed as `output_dim`
            of the inner `Embedding`).
        **kwargs: Forwarded unchanged to `layers.Layer`.
    """

    def __init__(self, sequence_length, output_dim, **kwargs):
        super().__init__(**kwargs)
        # One trainable vector per position index.
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length,
            output_dim=output_dim,
        )
        # Stored so that `get_config` can report the constructor arguments.
        self.sequence_length = sequence_length
        self.output_dim = output_dim

    def call(self, inputs):
        """Return `inputs` plus the embedding of each position along axis 1."""
        # The inputs are of shape: `(batch_size, frames, num_features)`
        # NOTE(review): the addition presumably requires `num_features` to
        # equal `output_dim` so the position vectors broadcast over the
        # batch axis -- confirm against the caller.
        num_frames = tf.shape(inputs)[1]
        frame_indices = tf.range(start=0, limit=num_frames, delta=1)
        return inputs + self.position_embeddings(frame_indices)

    def compute_mask(self, inputs, mask=None):
        """Flag an entry as valid when any value on its last axis is non-zero.

        The incoming `mask` argument is not consulted; the result is
        derived from `inputs` alone.
        """
        is_nonzero = tf.cast(inputs, "bool")
        return tf.reduce_any(is_nonzero, axis=-1)

    def get_config(self):
        """Return the base layer config extended with the constructor args."""
        config = super().get_config()
        config["sequence_length"] = self.sequence_length
        config["output_dim"] = self.output_dim
        return config
class TransformerEncoder(layers.Layer):
    """Self-attention followed by a two-layer dense projection.

    Each of the two stages is wrapped in a residual connection and a
    `LayerNormalization`.

    Args:
        embed_dim: Used as `key_dim` of the attention layer and as the
            output width of the dense projection.
        dense_dim: Width of the hidden (GELU-activated) dense layer.
        num_heads: Number of attention heads.
        **kwargs: Forwarded unchanged to `layers.Layer`.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        # Stored so that `get_config` can report the constructor arguments.
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        # The attention dropout rate is fixed at 0.3 rather than configurable.
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim,
            dropout=0.3,
        )
        # Expand to `dense_dim` with GELU, then project back to `embed_dim`.
        feed_forward_layers = [
            layers.Dense(dense_dim, activation=tf.nn.gelu),
            layers.Dense(embed_dim),
        ]
        self.dense_proj = keras.Sequential(feed_forward_layers)
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

    def call(self, inputs, mask=None):
        """Apply the attention and projection stages to `inputs`.

        Args:
            inputs: Tensor used as both query and value of the attention.
            mask: Optional mask; when given, a new axis is inserted at
                position 1 before it is passed as `attention_mask`.
                NOTE(review): presumably a `(batch, frames)` mask that
                becomes `(batch, 1, frames)` -- confirm against the caller.

        Returns:
            The output of the second layer normalization.
        """
        attention_mask = None if mask is None else mask[:, tf.newaxis, :]
        attended = self.attention(inputs, inputs, attention_mask=attention_mask)
        # First residual connection + normalization.
        normalized = self.layernorm_1(inputs + attended)
        projected = self.dense_proj(normalized)
        # Second residual connection + normalization.
        return self.layernorm_2(normalized + projected)

    def get_config(self):
        """Return the base layer config extended with the constructor args."""
        config = super().get_config()
        config["embed_dim"] = self.embed_dim
        config["dense_dim"] = self.dense_dim
        config["num_heads"] = self.num_heads
        return config
|