SRMRI_RRDB_Cross_Scale_Cross_Contrast / SRMRIModelsConfigs.py
venkatesh-thiru's picture
Upload model
75507f4 verified
raw
history blame
1.57 kB
from transformers import PretrainedConfig
from typing import List
class RRDBConfiguration(PretrainedConfig):
    """Hugging Face configuration for the ``SRMRIModelRRDB`` model.

    Every named constructor argument is stored unchanged as an attribute of
    the same name. Any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``. How each value is interpreted is decided by
    the model class that consumes this config, which is not part of this file.
    The per-argument notes below are inferred from the names and should be
    confirmed against that model.

    Args:
        nChannels: Presumably the number of image channels.
        nDenseLayers: Presumably the number of layers per dense block.
        nInitFeat: Presumably the number of feature maps produced by the
            initial convolution.
        GrowthRate: Presumably the dense-block growth rate.
        featureFusion: Presumably toggles a feature-fusion stage.
        kernel_config: Presumably a list of convolution kernel sizes. When
            omitted (or ``None``), a fresh ``[3, 3, 3, 3]`` is used.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "SRMRIModelRRDB"

    def __init__(
        self,
        nChannels=1,
        nDenseLayers=6,
        nInitFeat=6,
        GrowthRate=12,
        featureFusion=True,
        kernel_config=None,
        **kwargs):
        # A list literal as the default would be created once at class
        # definition time and shared by every config built without an
        # explicit value. Mutating one config's kernel_config would then leak
        # into all later ones. Build a fresh list per call instead; the
        # effective default is unchanged.
        if kernel_config is None:
            kernel_config = [3, 3, 3, 3]
        self.nChannels = nChannels
        self.nDenseLayers = nDenseLayers
        self.nInitFeat = nInitFeat
        self.GrowthRate = GrowthRate
        self.featureFusion = featureFusion
        self.kernel_config = kernel_config
        # Set the model-specific attributes first, then let PretrainedConfig
        # process the remaining (generic) keyword arguments.
        super().__init__(**kwargs)
class UNetMSSConfiguration(PretrainedConfig):
    """Hugging Face configuration for the ``SRMRIModelUNetMSS`` model.

    Each named constructor argument is copied verbatim onto the instance under
    the same attribute name, in declaration order. Only after that are the
    remaining keyword arguments handed to ``PretrainedConfig.__init__``. How
    the values are interpreted is up to the model class that reads this
    config (not shown here). The notes below are inferred from names only.

    Args:
        in_channels: Presumably the number of input channels.
        n_classes: Presumably the number of output channels.
        depth: Presumably the number of encoder/decoder levels.
        wf: Presumably a width factor controlling filter counts; confirm.
        padding, batch_norm, up_mode, dropout: Layer options stored for the
            model to consume.
        mss_level, mss_fromlatent, mss_up, mss_interpb4: Options whose
            ``mss`` prefix presumably refers to multi-scale supervision;
            confirm against the model code.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "SRMRIModelUNetMSS"

    def __init__(
        self,
        in_channels=1,
        n_classes=1,
        depth=3,
        wf=6,
        padding=True,
        batch_norm=False,
        up_mode='upconv',
        dropout=False,
        mss_level=2,
        mss_fromlatent=True,
        mss_up="trilinear",
        mss_interpb4=True,
        **kwargs):
        # Insertion order of this mapping fixes the order in which the
        # attributes are assigned.
        hyperparams = {
            "in_channels": in_channels,
            "n_classes": n_classes,
            "depth": depth,
            "wf": wf,
            "padding": padding,
            "batch_norm": batch_norm,
            "up_mode": up_mode,
            "dropout": dropout,
            "mss_level": mss_level,
            "mss_fromlatent": mss_fromlatent,
            "mss_up": mss_up,
            "mss_interpb4": mss_interpb4,
        }
        # setattr dispatches through the same __setattr__ as a plain
        # ``self.name = value`` assignment.
        for attr_name, attr_value in hyperparams.items():
            setattr(self, attr_name, attr_value)
        # Generic PretrainedConfig options are processed last.
        super().__init__(**kwargs)