File size: 960 Bytes
fb22435 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
from transformers import PretrainedConfig
"""
编写自定义配置时需要记住的三个重要事项如下:
必须继承自 PretrainedConfig,
PretrainedConfig 的 __init__ 方法必须接受任何 kwargs,
这些 kwargs 需要传递给超类的 __init__ 方法。
"""
class MyResnetConfig(PretrainedConfig):
model_type = "resnet"
def __init__(
self,
num_classes: int = 176, # 分类数
in_channels: int = 3, # 输入通道数
num_channels: int = 64, # 第一个卷积的输出通道数
num_residuals=None, # 每个残差块组合里残差块的数量
**kwargs,
):
self.num_classes = num_classes
self.in_channels = in_channels
self.num_channels = num_channels
if num_residuals is None:
num_residuals = [2, 2, 2, 2]
self.num_residuals = num_residuals
super().__init__(**kwargs)
|