File size: 670 Bytes
40ed350 b4f3d8a 40ed350 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
from transformers import PreTrainedModel
from AUNet import AUNet
from AUNetConfig import AUNetConfig
import torch
class s2l8hModel(PreTrainedModel):
    """``transformers`` ``PreTrainedModel`` wrapper around an ``AUNet`` network.

    Subclassing ``PreTrainedModel`` with ``config_class = AUNetConfig`` lets
    the network be built from an ``AUNetConfig`` and used with the
    ``transformers`` model machinery (e.g. ``from_pretrained`` /
    ``save_pretrained``). The actual computation is delegated to the wrapped
    ``AUNet`` stored in ``self.model``.
    """

    config_class = AUNetConfig

    def __init__(self, config):
        """Build the wrapped ``AUNet`` from ``config``.

        Args:
            config: An ``AUNetConfig``. Its ``in_channels``, ``out_channels``,
                ``depth``, ``spatial_attention``, ``growth_factor``,
                ``interp_mode``, ``up_mode`` and ``ca_layer`` attributes are
                passed unchanged as keyword arguments to ``AUNet``.
        """
        super().__init__(config)
        self.model = AUNet(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            depth=config.depth,
            spatial_attention=config.spatial_attention,
            growth_factor=config.growth_factor,
            interp_mode=config.interp_mode,
            up_mode=config.up_mode,
            ca_layer=config.ca_layer,
        )

    def forward(self, MS, PAN):
        """Run the wrapped ``AUNet`` on ``(MS, PAN)``.

        Args:
            MS: First network input. NOTE(review): presumably a multispectral
                image tensor (pan-sharpening setup) — confirm against AUNet.
            PAN: Second network input. NOTE(review): presumably a
                panchromatic image tensor — confirm against AUNet.

        Returns:
            Whatever ``AUNet`` produces for these two inputs.
        """
        # Invoke the module itself rather than ``.forward`` so that
        # ``nn.Module.__call__`` runs any registered forward (pre-)hooks on
        # the wrapped network; calling ``.forward`` directly skips them.
        return self.model(MS, PAN)