|
from transformers import PretrainedConfig |
|
from typing import List |
|
|
|
|
|
class AUNetConfig(PretrainedConfig):
    """Configuration container for the AUNet model.

    Stores the architecture hyperparameters as plain instance attributes so
    they travel with the ``PretrainedConfig`` save/load machinery. Every
    keyword argument not named below is forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        in_channels: Number of input channels (default ``7``).
        out_channels: Number of output channels (default ``6``).
        depth: Network depth (default ``5``); presumably the number of
            encoder/decoder levels -- confirm against the model code.
        spatial_attention: Spatial-attention selector. The default is the
            *string* ``'None'``, not the ``None`` object; presumably the
            model treats it as "disabled" -- verify against the model.
        growth_factor: Growth factor (default ``6``); its meaning is defined
            by the model that consumes this config, not here.
        interp_mode: Interpolation mode name (default ``'bicubic'``).
        up_mode: Upsampling strategy name (default ``'upsample'``).
        ca_layer: Boolean flag for an optional "CA" layer (default
            ``False``); presumably channel attention -- confirm.
    """

    # Model-type identifier that transformers serializes with the config.
    # NOTE(review): it does not mirror the class name; keep it as-is, since
    # changing it would break loading of configs saved with this value.
    model_type = "s2l8hModel"

    def __init__(
        self,
        in_channels: int = 7,
        out_channels: int = 6,
        depth: int = 5,
        spatial_attention: str = 'None',
        growth_factor: int = 6,
        interp_mode: str = 'bicubic',
        up_mode: str = 'upsample',
        ca_layer: bool = False,
        **kwargs,
    ):
        # Collect the model-specific fields in one place, then attach them.
        # setattr goes through the same __setattr__ as `self.x = ...`.
        hyperparams = {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "depth": depth,
            "spatial_attention": spatial_attention,
            "growth_factor": growth_factor,
            "interp_mode": interp_mode,
            "up_mode": up_mode,
            "ca_layer": ca_layer,
        }
        for attr_name, attr_value in hyperparams.items():
            setattr(self, attr_name, attr_value)

        # The base initializer runs last and receives only the leftover kwargs.
        super().__init__(**kwargs)