File size: 826 Bytes
40ed350 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
from transformers import PretrainedConfig
from typing import List
class AUNetConfig(PretrainedConfig):
    """Hugging Face ``PretrainedConfig`` subclass holding AUNet hyper-parameters.

    Every constructor argument is stored on the instance under an attribute of
    the same name. ``PretrainedConfig`` can then include it when the config is
    serialized. Any additional keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        in_channels: Presumably the channel count of the model input
            (default 7). TODO confirm against the model code.
        out_channels: Presumably the channel count of the model output
            (default 6).
        depth: Presumably the number of resolution levels in the network
            (default 5). Verify against the model code.
        spatial_attention: Selector for a spatial-attention variant.
            NOTE(review): the default is the *string* ``"None"``, not the
            ``None`` object. Consumers presumably compare against the string;
            check before changing.
        growth_factor: Presumably controls how feature width grows between
            levels (default 6). TODO confirm.
        interp_mode: Interpolation mode name (default ``"bicubic"``).
            Presumably passed to an interpolation/resize call.
        up_mode: Upsampling strategy selector (default ``"upsample"``).
        ca_layer: Boolean toggle for an optional layer (default ``False``).
            "ca" likely means channel attention; verify in the model code.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    # Class-level identifier that transformers uses to recognise this config
    # type (e.g. when registering with the Auto* classes).
    model_type = "s2l8hModel"

    def __init__(
        self,
        in_channels: int = 7,
        out_channels: int = 6,
        depth: int = 5,
        spatial_attention: str = "None",
        growth_factor: int = 6,
        interp_mode: str = "bicubic",
        up_mode: str = "upsample",
        ca_layer: bool = False,
        **kwargs,
    ) -> None:
        # Channel counts at the network boundaries.
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Architecture-shape hyper-parameters.
        self.depth = depth
        self.spatial_attention = spatial_attention
        self.growth_factor = growth_factor

        # Resampling / optional-layer switches.
        self.interp_mode = interp_mode
        self.up_mode = up_mode
        self.ca_layer = ca_layer

        # Custom fields are set first. The base class then receives only the
        # leftover keyword arguments (generic config options it understands).
        super().__init__(**kwargs)