from transformers import PretrainedConfig
from typing import List, Optional


class RRDBConfiguration(PretrainedConfig):
    """Configuration container registered under model type ``SRMRIModelRRDB``.

    Every constructor argument is stored unchanged as an attribute of the
    same name, so it is serialised with the rest of the config. The model
    that consumes these values is not part of this file; the parameter
    descriptions below are inferred from the names and should be confirmed
    against that model.

    Args:
        nChannels: Presumably the number of image channels -- TODO confirm.
        nDenseLayers: Presumably the number of layers per dense block --
            TODO confirm.
        nInitFeat: Presumably the number of initial feature maps --
            TODO confirm.
        GrowthRate: Presumably the dense-block growth rate -- TODO confirm.
        featureFusion: Presumably toggles a feature-fusion stage --
            TODO confirm.
        kernel_config: Sequence of integers (presumably kernel sizes).
            ``None`` selects the default ``[3, 3, 3, 3]``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "SRMRIModelRRDB"

    def __init__(
        self,
        nChannels: int = 1,
        nDenseLayers: int = 6,
        nInitFeat: int = 6,
        GrowthRate: int = 12,
        featureFusion: bool = True,
        kernel_config: Optional[List[int]] = None,
        **kwargs,
    ):
        self.nChannels = nChannels
        self.nDenseLayers = nDenseLayers
        self.nInitFeat = nInitFeat
        self.GrowthRate = GrowthRate
        self.featureFusion = featureFusion
        # A list literal in the signature would be created once and shared by
        # every instance; build a fresh list per instance instead (and copy a
        # caller-supplied sequence so later edits do not leak either way).
        if kernel_config is None:
            self.kernel_config = [3, 3, 3, 3]
        else:
            self.kernel_config = list(kernel_config)
        super().__init__(**kwargs)


class UNetMSSConfiguration(PretrainedConfig):
    """Configuration container registered under model type ``SRMRIModelUNetMSS``.

    Every constructor argument is stored unchanged as an attribute of the
    same name. The model that consumes these values is not part of this
    file; the parameter descriptions below are inferred from the names and
    should be confirmed against that model.

    Args:
        in_channels: Presumably the number of input channels -- TODO confirm.
        n_classes: Presumably the number of output channels/classes --
            TODO confirm.
        depth: Presumably the depth of the U-Net -- TODO confirm.
        wf: Presumably a width factor for the number of filters --
            TODO confirm.
        padding: Presumably whether convolutions are padded -- TODO confirm.
        batch_norm: Presumably whether batch normalisation is used --
            TODO confirm.
        up_mode: Upsampling mode identifier; defaults to ``'upconv'``.
        dropout: Presumably whether dropout is used -- TODO confirm.
        mss_level: Presumably the number of multi-scale supervision levels
            -- TODO confirm.
        mss_fromlatent: Presumably whether supervision starts from the
            latent level -- TODO confirm.
        mss_up: Interpolation mode identifier; defaults to ``"trilinear"``.
        mss_interpb4: Presumably whether interpolation happens before the
            supervision output -- TODO confirm.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "SRMRIModelUNetMSS"

    def __init__(
        self,
        in_channels: int = 1,
        n_classes: int = 1,
        depth: int = 3,
        wf: int = 6,
        padding: bool = True,
        batch_norm: bool = False,
        up_mode: str = 'upconv',
        dropout: bool = False,
        mss_level: int = 2,
        mss_fromlatent: bool = True,
        mss_up: str = "trilinear",
        mss_interpb4: bool = True,
        **kwargs,
    ):
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.depth = depth
        self.wf = wf
        self.padding = padding
        self.batch_norm = batch_norm
        self.up_mode = up_mode
        self.dropout = dropout
        self.mss_level = mss_level
        self.mss_fromlatent = mss_fromlatent
        self.mss_up = mss_up
        self.mss_interpb4 = mss_interpb4
        super().__init__(**kwargs)