from transformers import PretrainedConfig
class Resnet50Config(PretrainedConfig):
    """Configuration object for the ResNet-50 image-classification model.

    Extends ``PretrainedConfig`` with a single extra field, ``num_classes``;
    every other keyword argument is passed through to the base class
    untouched.
    """

    # NOTE: no ``model_type`` is assigned on this class. The assignment
    # ``model_type = "MobileNetV1"`` (a stand-in chosen because it is close
    # to the image-classification task) is deliberately left disabled here.

    def __init__(self, num_classes=6, **kwargs):
        """Store the class count and delegate the rest to the base config.

        Args:
            num_classes: Value stored as ``self.num_classes`` (presumably the
                number of target classes of the classifier). Defaults to 6.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        # Set the custom attribute first, then let the base class consume
        # the remaining keyword arguments.
        self.num_classes = num_classes
        super().__init__(**kwargs)