|
from transformers import PretrainedConfig |
|
from typing import List |
|
|
|
class RRDBConfiguration(PretrainedConfig):
    """Configuration holder for the ``SRMRIModelRRDB`` model type.

    Each constructor argument is stored unchanged as an instance attribute of
    the same name; any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``. The meaning of each value is decided by the
    model that consumes this configuration, which is not visible in this file,
    so the descriptions below are assumptions to be confirmed.

    Args:
        nChannels: Presumably the number of image channels - TODO confirm.
        nDenseLayers: Presumably the number of dense layers - TODO confirm.
        nInitFeat: Presumably the initial feature count - TODO confirm.
        GrowthRate: Presumably the dense-block growth rate - TODO confirm.
        featureFusion: Boolean flag; presumably toggles feature fusion -
            TODO confirm.
        kernel_config: Sequence of integers (presumably kernel sizes - TODO
            confirm). ``None`` selects the default ``[3, 3, 3, 3]``, created
            fresh for every instance.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "SRMRIModelRRDB"

    def __init__(
        self,
        nChannels=1,
        nDenseLayers=6,
        nInitFeat=6,
        GrowthRate=12,
        featureFusion=True,
        kernel_config=None,
        **kwargs,
    ):
        self.nChannels = nChannels
        self.nDenseLayers = nDenseLayers
        self.nInitFeat = nInitFeat
        self.GrowthRate = GrowthRate
        self.featureFusion = featureFusion
        # A list literal in the signature would be created once and shared by
        # every instance relying on the default, so mutating one config's
        # kernel_config would leak into all others. Build a new list instead.
        if kernel_config is None:
            kernel_config = [3, 3, 3, 3]
        self.kernel_config = kernel_config
        # Attributes are set before delegating, as in the original ordering.
        super().__init__(**kwargs)
|
|
|
class UNetMSSConfiguration(PretrainedConfig):
    """Configuration holder for the ``SRMRIModelUNetMSS`` model type.

    Every constructor argument is copied verbatim onto the instance under the
    same name, after which the remaining keyword arguments are handed to
    ``PretrainedConfig.__init__``. How the values are interpreted is up to the
    model consuming this configuration (not visible in this file).
    """

    model_type = "SRMRIModelUNetMSS"

    def __init__(
        self,
        in_channels=1,
        n_classes=1,
        depth=3,
        wf=6,
        padding=True,
        batch_norm=False,
        up_mode='upconv',
        dropout=False,
        mss_level=2,
        mss_fromlatent=True,
        mss_up="trilinear",
        mss_interpb4=True,
        **kwargs,
    ):
        # Gather the hyperparameters in declaration order so they are assigned
        # in the same sequence as one-by-one attribute assignments would be.
        hyperparameters = {
            "in_channels": in_channels,
            "n_classes": n_classes,
            "depth": depth,
            "wf": wf,
            "padding": padding,
            "batch_norm": batch_norm,
            "up_mode": up_mode,
            "dropout": dropout,
            "mss_level": mss_level,
            "mss_fromlatent": mss_fromlatent,
            "mss_up": mss_up,
            "mss_interpb4": mss_interpb4,
        }
        for attribute, value in hyperparameters.items():
            setattr(self, attribute, value)

        # Base-class initialisation runs last, after all attributes are set.
        super().__init__(**kwargs)