# File size: 4,430 Bytes
# 2a13495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# import timm
import functools
import torch.utils.model_zoo as model_zoo

# from .resnet import resnet_encoders
# from .dpn import dpn_encoders
# from .vgg import vgg_encoders
# from .senet import senet_encoders
# from .densenet import densenet_encoders
# from .inceptionresnetv2 import inceptionresnetv2_encoders
# from .inceptionv4 import inceptionv4_encoders
# from .efficientnet import efficient_net_encoders
# from .mobilenet import mobilenet_encoders
# from .xception import xception_encoders
# from .timm_efficientnet import timm_efficientnet_encoders
# from .timm_resnest import timm_resnest_encoders
# from .timm_res2net import timm_res2net_encoders
# from .timm_regnet import timm_regnet_encoders
# from .timm_sknet import timm_sknet_encoders
# from .timm_mobilenetv3 import timm_mobilenetv3_encoders
# from .timm_gernet import timm_gernet_encoders
from .mix_transformer import mix_transformer_encoders

# from .timm_universal import TimmUniversalEncoder

# from ._preprocessing import preprocess_input

# Registry: encoder name -> spec dict. The functions below read these keys:
#   "encoder"             - callable that builds the encoder from **params
#   "params"              - constructor keyword arguments
#   "pretrained_settings" - weights name -> settings dict ("url", "mean", "std", ...)
# Only the Mix Transformer family is enabled. The other families are kept
# here as a commented-out menu that matches the commented-out imports above.
encoders = {
    # **resnet_encoders,
    # **dpn_encoders,
    # **vgg_encoders,
    # **senet_encoders,
    # **densenet_encoders,
    # **inceptionresnetv2_encoders,
    # **inceptionv4_encoders,
    # **efficient_net_encoders,
    # **mobilenet_encoders,
    # **xception_encoders,
    # **timm_efficientnet_encoders,
    # **timm_resnest_encoders,
    # **timm_res2net_encoders,
    # **timm_regnet_encoders,
    # **timm_sknet_encoders,
    # **timm_mobilenetv3_encoders,
    # **timm_gernet_encoders,
    **mix_transformer_encoders,
}


def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    """Build an encoder by registry name, optionally loading pretrained weights.

    Args:
        name: Key in the module-level ``encoders`` registry, or
            ``"tu-<model name>"`` to build a ``TimmUniversalEncoder`` instead.
        in_channels: Number of input channels; passed to ``set_in_channels``
            (or to ``TimmUniversalEncoder`` for ``"tu-"`` names).
        depth: Encoder depth; passed to the encoder constructor.
        weights: Key into the encoder's ``"pretrained_settings"``. ``None``
            skips weight loading.
        output_stride: When not 32, ``make_dilated(output_stride)`` is called
            on registry encoders.
        **kwargs: Forwarded to ``TimmUniversalEncoder`` only; ignored for
            registry encoders.

    Returns:
        The constructed encoder instance.

    Raises:
        KeyError: If ``name`` is not registered, or ``weights`` is not one of
            the encoder's pretrained settings.
        ImportError: If a ``"tu-"`` name is requested but the
            ``timm_universal`` module cannot be imported.
    """
    if name.startswith("tu-"):
        # Imported lazily: the module-level import is commented out in this
        # file, so referring to the bare name would raise NameError.
        try:
            from .timm_universal import TimmUniversalEncoder
        except ImportError as err:
            raise ImportError(
                "`tu-` encoders require the `timm_universal` module (and timm)"
            ) from err
        return TimmUniversalEncoder(
            name=name[3:],
            in_channels=in_channels,
            depth=depth,
            output_stride=output_stride,
            pretrained=weights is not None,
            **kwargs,
        )

    try:
        spec = encoders[name]
    except KeyError:
        raise KeyError(
            "Wrong encoder name `{}`, supported encoders: {}".format(
                name, list(encoders.keys())
            )
        ) from None

    Encoder = spec["encoder"]
    # Build a fresh kwargs dict so the shared registry entry is never mutated.
    params = {**spec["params"], "depth": depth}
    encoder = Encoder(**params)

    if weights is not None:
        pretrained_settings = spec.get("pretrained_settings", {})
        try:
            settings = pretrained_settings[weights]
        except KeyError:
            raise KeyError(
                "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
                    weights, name, list(pretrained_settings.keys()),
                )
            ) from None
        encoder.load_state_dict(model_zoo.load_url(settings["url"]))

    encoder.set_in_channels(in_channels, pretrained=weights is not None)
    if output_stride != 32:
        encoder.make_dilated(output_stride)

    return encoder


def get_encoder_names():
    """Return a new list of every registered encoder name, in registry order."""
    names = [*encoders]
    return names


def get_preprocessing_params(encoder_name, pretrained="imagenet"):
    """Return the input-normalization settings for an encoder's pretrained weights.

    Args:
        encoder_name: Key in the module-level ``encoders`` registry, or
            ``"tu-<model name>"`` to look the model up in timm's registry.
        pretrained: Key into the encoder's ``"pretrained_settings"``. Used only
            for registry encoders, not for ``"tu-"`` names.

    Returns:
        dict with keys ``"input_space"`` (default ``"RGB"``), ``"input_range"``
        (default ``[0, 1]``), ``"mean"`` and ``"std"``. The last three are
        converted to lists.

    Raises:
        KeyError: If ``encoder_name`` is not in the registry.
        ValueError: If the weights/preprocessing settings are unavailable.
        ImportError: If a ``"tu-"`` name is requested and timm is not installed.
    """
    if encoder_name.startswith("tu-"):
        encoder_name = encoder_name[3:]
        # Imported lazily: the module-level ``import timm`` is commented out in
        # this file, so referring to the bare name would raise NameError.
        try:
            import timm
        except ImportError as err:
            raise ImportError("`tu-` encoders require the timm package") from err
        # NOTE(review): these are private timm registry attributes; newer timm
        # releases may have moved them. Confirm against the installed version.
        if encoder_name not in timm.models.registry._model_has_pretrained:
            raise ValueError(
                f"{encoder_name} does not have pretrained weights and preprocessing parameters"
            )
        settings = timm.models.registry._model_default_cfgs[encoder_name]
    else:
        if encoder_name not in encoders:
            # Same message and exception type that get_encoder uses for unknown names.
            raise KeyError(
                "Wrong encoder name `{}`, supported encoders: {}".format(
                    encoder_name, list(encoders.keys())
                )
            )
        all_settings = encoders[encoder_name]["pretrained_settings"]
        if pretrained not in all_settings:
            raise ValueError(
                "Available pretrained options {}".format(all_settings.keys())
            )
        settings = all_settings[pretrained]

    return {
        "input_space": settings.get("input_space", "RGB"),
        "input_range": list(settings.get("input_range", [0, 1])),
        "mean": list(settings.get("mean")),
        "std": list(settings.get("std")),
    }


def get_preprocessing_fn(encoder_name, pretrained="imagenet"):
    """Return ``preprocess_input`` with the encoder's normalization params bound.

    Args:
        encoder_name: Encoder name as accepted by ``get_preprocessing_params``.
        pretrained: Pretrained-weights key as accepted by
            ``get_preprocessing_params``.

    Returns:
        A ``functools.partial`` of ``preprocess_input`` with ``input_space``,
        ``input_range``, ``mean`` and ``std`` bound as keyword arguments.
    """
    params = get_preprocessing_params(encoder_name, pretrained=pretrained)
    # Imported lazily: the module-level import of ``preprocess_input`` is
    # commented out in this file, so the bare name would raise NameError.
    # The import comes after the lookup so invalid names keep their original errors.
    from ._preprocessing import preprocess_input

    return functools.partial(preprocess_input, **params)