File size: 1,568 Bytes
75507f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
from typing import List, Optional

from transformers import PretrainedConfig
class RRDBConfiguration(PretrainedConfig):
    """Configuration for the ``SRMRIModelRRDB`` model.

    Stores the hyper-parameters below as instance attributes of the same
    name and forwards every remaining keyword argument to
    ``PretrainedConfig.__init__``.

    NOTE(review): the parameter meanings below are inferred from their
    names only; confirm them against the model implementation.

    Args:
        nChannels: Presumably the number of image channels. Default 1.
        nDenseLayers: Presumably the number of layers per dense block.
            Default 6.
        nInitFeat: Presumably the number of initial feature maps.
            Default 6.
        GrowthRate: Presumably the dense-block growth rate. Default 12.
        featureFusion: Presumably toggles feature fusion. Default True.
        kernel_config: List of ints, presumably kernel sizes. ``None``
            (the default) selects a fresh ``[3, 3, 3, 3]`` for each
            instance.
        **kwargs: Passed unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "SRMRIModelRRDB"

    def __init__(
        self,
        nChannels: int = 1,
        nDenseLayers: int = 6,
        nInitFeat: int = 6,
        GrowthRate: int = 12,
        featureFusion: bool = True,
        kernel_config: Optional[List[int]] = None,
        **kwargs,
    ):
        # A list literal used as a default is evaluated once at definition
        # time and shared by every instance. Build a fresh list per call so
        # that mutating one config cannot leak into the others.
        if kernel_config is None:
            kernel_config = [3, 3, 3, 3]
        self.nChannels = nChannels
        self.nDenseLayers = nDenseLayers
        self.nInitFeat = nInitFeat
        self.GrowthRate = GrowthRate
        self.featureFusion = featureFusion
        self.kernel_config = kernel_config
        super().__init__(**kwargs)
class UNetMSSConfiguration(PretrainedConfig):
    """Configuration for the ``SRMRIModelUNetMSS`` model.

    Each constructor argument is stored as an instance attribute of the
    same name. Any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``, which is called after the attributes
    have been set.

    NOTE(review): the descriptions below are inferred from parameter names
    only; confirm them against the model implementation.

    Args:
        in_channels: Presumably the number of input channels. Default 1.
        n_classes: Presumably the number of output channels/classes.
            Default 1.
        depth: Presumably the number of UNet levels. Default 3.
        wf: Presumably a width factor for the filter count. Default 6.
        padding: Presumably toggles padded convolutions. Default True.
        batch_norm: Presumably toggles batch normalisation. Default False.
        up_mode: Up-sampling mode string. Default ``'upconv'``.
        dropout: Presumably toggles dropout. Default False.
        mss_level: Multi-scale supervision setting (presumably the number
            of supervised levels). Default 2.
        mss_fromlatent: Multi-scale supervision flag. Default True.
        mss_up: Up-sampling mode string for multi-scale supervision.
            Default ``"trilinear"``.
        mss_interpb4: Multi-scale supervision flag. Default True.
        **kwargs: Passed unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "SRMRIModelUNetMSS"

    def __init__(
        self,
        in_channels: int = 1,
        n_classes: int = 1,
        depth: int = 3,
        wf: int = 6,
        padding: bool = True,
        batch_norm: bool = False,
        up_mode: str = 'upconv',
        dropout=False,
        mss_level: int = 2,
        mss_fromlatent: bool = True,
        mss_up: str = "trilinear",
        mss_interpb4: bool = True,
        **kwargs,
    ):
        # Network shape.
        self.in_channels, self.n_classes = in_channels, n_classes
        self.depth, self.wf = depth, wf
        # Per-layer options.
        self.padding, self.batch_norm = padding, batch_norm
        self.dropout, self.up_mode = dropout, up_mode
        # Multi-scale supervision ("mss_*") options.
        self.mss_level, self.mss_fromlatent = mss_level, mss_fromlatent
        self.mss_up, self.mss_interpb4 = mss_up, mss_interpb4
        # The base initialiser runs last, as in the original ordering.
        super().__init__(**kwargs)