Spaces:
Build error
Build error
import tensorflow as tf | |
import tensorflow.experimental.numpy as tnp | |
from tensorflow.keras import Model | |
import raven_utils as rv | |
class RangeMask(Model): | |
def __init__(self): | |
super().__init__() | |
ranges = tf.tile(tf.range(rv.entity.INDEX[-1])[None], [rv.group.NO, 1]) | |
start_index = rv.entity.INDEX[:-1][:, None] | |
end_index = rv.entity.INDEX[1:][:, None] | |
self.mask = tnp.array((start_index <= ranges) & (ranges < end_index)) | |
def call(self, inputs): | |
return self.mask[inputs] |