import torch
import torch.nn as nn


def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True):
    """Change the number of input channels of the model's first convolution.

    The first ``nn.Conv2d`` (in ``model.modules()`` order) whose
    ``in_channels`` equals ``default_in_channels`` is patched in place.

    How the new weight is built:
        * ``pretrained=False``: a fresh weight of the new shape is allocated
          and the layer is re-initialized with ``module.reset_parameters()``.
        * ``new_in_channels == 1``: the original weight is summed over the
          input-channel dimension.
        * otherwise: the original input-channel slices are reused cyclically
          (new channel ``i`` copies old channel ``i % old_channels``) and the
          result is rescaled by ``old_channels / new_channels``.

    Args:
        model: module whose first matching convolution is patched in place.
        new_in_channels: desired number of input channels.
        default_in_channels: input channel count used to find the layer.
        pretrained: if True, derive the new weight from the existing one;
            if False, re-initialize the layer.

    Raises:
        ValueError: if the model has no ``nn.Conv2d`` with
            ``in_channels == default_in_channels``.
    """
    # Locate the first conv explicitly; without this guard a model with no
    # matching layer would have its *last* module patched by accident.
    first_conv = None
    for module in model.modules():
        if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels:
            first_conv = module
            break
    if first_conv is None:
        raise ValueError(
            "No nn.Conv2d with in_channels == {} found in model".format(
                default_in_channels
            )
        )

    weight = first_conv.weight.detach()
    first_conv.in_channels = new_in_channels

    # Conv2d weights have shape (out_channels, in_channels // groups, kH, kW).
    new_per_group = new_in_channels // first_conv.groups

    if not pretrained:
        # Allocate on the same dtype/device as the old weight so the patched
        # layer stays consistent with the rest of the model.
        first_conv.weight = nn.parameter.Parameter(
            torch.empty(
                first_conv.out_channels,
                new_per_group,
                *first_conv.kernel_size,
                dtype=weight.dtype,
                device=weight.device,
            )
        )
        first_conv.reset_parameters()
        return

    if new_in_channels == 1:
        first_conv.weight = nn.parameter.Parameter(weight.sum(1, keepdim=True))
        return

    # Cyclic reuse of the existing per-group input-channel slices.
    old_per_group = weight.shape[1]
    source_idx = torch.arange(new_per_group, device=weight.device) % old_per_group
    new_weight = weight[:, source_idx] * (old_per_group / new_per_group)
    first_conv.weight = nn.parameter.Parameter(new_weight)


def replace_strides_with_dilation(module, dilation_rate):
    """Patch Conv2d modules in place, replacing strides with dilation.

    Every ``nn.Conv2d`` found in ``module.modules()`` gets stride ``(1, 1)``,
    dilation ``(dilation_rate, dilation_rate)`` and padding
    ``(kh // 2 * dilation_rate, kw // 2 * dilation_rate)``.

    Args:
        module: module whose convolutions are patched in place.
        dilation_rate: dilation applied to both spatial dimensions.
    """
    for mod in module.modules():
        if not isinstance(mod, nn.Conv2d):
            continue

        mod.stride = (1, 1)
        mod.dilation = (dilation_rate, dilation_rate)
        kh, kw = mod.kernel_size
        # Pad each spatial dimension according to its own kernel extent.
        mod.padding = ((kh // 2) * dilation_rate, (kw // 2) * dilation_rate)

        # Workaround for EfficientNet (per the original comment): modules
        # exposing `static_padding` get it replaced with a no-op Identity,
        # presumably so only the padding set above is applied - TODO confirm.
        if hasattr(mod, "static_padding"):
            mod.static_padding = nn.Identity()