import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat

import models
from models import register
from utils import to_coordinates
from mmcv.cnn import ConvModule

from .blocks.CSPLayer import CSPLayer


@register('funsr')
class FUNSR(nn.Module):
    """Continuous-scale super-resolution model.

    An encoder produces a feature map that modulates a coordinate-based
    local decoder; an optional global decoder operates on a per-image
    feature vector, and its output is either summed with the local
    prediction or fused with it through a 1x1 projection.
    """

    def __init__(self,
                 encoder_spec,
                 has_multiscale=False,
                 neck=None,
                 decoder=None,
                 global_decoder=None,
                 encoder_rgb=False,
                 n_forward_times=1,
                 encode_hr_coord=False,
                 has_bn=True,
                 encode_scale_ratio=False,
                 local_unfold=False,
                 weight_gen_func='nearest-exact',
                 return_featmap=False):
        super().__init__()
        self.weight_gen_func = weight_gen_func  # 'bilinear' or 'nearest-exact'
        self.encoder = models.make(encoder_spec)
        self.encoder_out_dim = self.encoder.out_dim
        self.encode_scale_ratio = encode_scale_ratio
        self.has_multiscale = has_multiscale
        self.encoder_rgb = encoder_rgb
        self.encode_hr_coord = encode_hr_coord
        self.local_unfold = local_unfold
        self.return_featmap = return_featmap

        # Optional strided-conv + CSP pyramid: 48 -> 24 -> 12 -> 6.
        self.multiscale_layers = nn.ModuleList()
        if self.has_multiscale:
            conv_cfg = None
            norm_cfg = dict(type='BN', momentum=0.03, eps=0.001) if has_bn else None
            act_cfg = dict(type='ReLU')
            num_blocks = [2, 4, 6]
            for n_idx in range(3):
                conv_layer = ConvModule(
                    self.encoder_out_dim,
                    self.encoder_out_dim * 2,
                    3,
                    stride=2,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
                csp_layer = CSPLayer(
                    self.encoder_out_dim * 2,
                    self.encoder_out_dim,
                    num_blocks=num_blocks[n_idx],
                    add_identity=True,
                    use_depthwise=False,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
                self.multiscale_layers.append(nn.Sequential(conv_layer, csp_layer))

        if neck is not None:
            self.neck = models.make(neck, args={'in_dim': self.encoder_out_dim})
            modulation_dim = self.neck.d_dim
        else:
            modulation_dim = self.encoder_out_dim

        self.n_forward_times = n_forward_times

        # The local decoder always sees the 2-channel relative coordinate;
        # the optional inputs widen it.
        decoder_in_dim = 2
        if self.encode_scale_ratio:
            decoder_in_dim += 2
        if self.encode_hr_coord:
            decoder_in_dim += 2
        if self.encoder_rgb:
            decoder_in_dim += 3

        if decoder is not None:
            if self.local_unfold:
                self.down_dim_layer = nn.Conv2d(modulation_dim * 9, modulation_dim, 1)
            self.decoder = models.make(
                decoder,
                args={'modulation_dim': modulation_dim, 'in_dim': decoder_in_dim})

        if global_decoder is not None:
            # The global decoder sees absolute HR coordinates instead of
            # relative ones, so the HR-coordinate extra channels never apply.
            decoder_in_dim = 2
            if self.encode_scale_ratio:
                decoder_in_dim += 2
            if self.encoder_rgb:
                decoder_in_dim += 3
            self.decoder_is_proj = global_decoder.get('is_proj', False)
            self.global_decoder = models.make(
                global_decoder,
                args={'modulation_dim': modulation_dim, 'in_dim': decoder_in_dim})
            if self.decoder_is_proj:
                self.input_proj = nn.Linear(modulation_dim, modulation_dim)
                # self.output_proj = nn.Conv2d(6, 3, kernel_size=3, padding=1)
                self.output_proj = nn.Conv2d(6, 3, kernel_size=1)

    def forward_step(self, lr_img, func_map, global_func, rel_coord, lr_coord,
                     hr_coord, scale_ratio_map=None, pred_rgb_value=None):
        # Optionally enrich each feature with its 3x3 neighbourhood before
        # expanding the function map to HR resolution.
        if self.local_unfold:
            b, c, h, w = func_map.shape
            func_map = F.unfold(func_map, 3, padding=1).view(b, c * 9, h, w)
            func_map = self.down_dim_layer(func_map)
        local_func_map = F.interpolate(func_map, size=hr_coord.shape[-2:],
                                       mode=self.weight_gen_func)

        rel_coord = repeat(rel_coord, 'b c h w -> (B b) c h w', B=lr_img.size(0))
        hr_coord = repeat(hr_coord, 'c h w -> B c h w', B=lr_img.size(0))

        # Assemble the local decoder input: relative coordinates plus the
        # optional scale-ratio, HR-coordinate, and RGB channels.
        local_input = rel_coord
        if self.encode_scale_ratio:
            local_input = torch.cat([local_input, scale_ratio_map], dim=1)
        if self.encode_hr_coord:
            local_input = torch.cat([local_input, hr_coord], dim=1)
        if self.encoder_rgb:
            if pred_rgb_value is None:
                # First pass: fall back to a bicubic upsample of the LR input.
                pred_rgb_value = F.interpolate(lr_img, size=hr_coord.shape[-2:],
                                               mode='bicubic', align_corners=True)
            local_input = torch.cat((local_input, pred_rgb_value), dim=1)

        decoder_output = self.decoder(local_input, local_func_map)

        returned_featmap = None  # stays None unless the projection branch runs
        if hasattr(self, 'global_decoder'):
            if self.decoder_is_proj:
                global_func = self.input_proj(global_func)  # B C
            global_func = repeat(global_func, 'B C -> B C H W',
                                 H=hr_coord.shape[2], W=hr_coord.shape[3])
            global_input = hr_coord
            if self.encode_scale_ratio:
                global_input = torch.cat([global_input, scale_ratio_map], dim=1)
            if self.encoder_rgb:
                if pred_rgb_value is None:
                    pred_rgb_value = F.interpolate(lr_img, size=hr_coord.shape[-2:],
                                                   mode='bicubic', align_corners=True)
                global_input = torch.cat((global_input, pred_rgb_value), dim=1)
            global_decoder_output = self.global_decoder(global_input, global_func)

            # Fuse the two branches: a 1x1 projection over the concatenation,
            # or a plain sum.
            if self.decoder_is_proj:
                fused = torch.cat((global_decoder_output, decoder_output), dim=1)
                if self.return_featmap:
                    returned_featmap = fused
                decoder_output = self.output_proj(fused)
            else:
                decoder_output = global_decoder_output + decoder_output
        return decoder_output, returned_featmap

    def forward_backbone(self, x, keep_ori_feat=True):
        # x: image, B x C x H x W
        x = self.encoder(x)
        output_feats = []
        if keep_ori_feat:
            output_feats.append(x)
        for layer in self.multiscale_layers:
            x = layer(x)
            output_feats.append(x)
        return output_feats

    def get_coordinate_map(self, x, hr_size):
        B, C, H, W = x.shape
        x_coord = to_coordinates(x.shape[-2:], return_map=True).to(x.device).permute(2, 0, 1)
        hr_coord = to_coordinates(hr_size, return_map=True).to(x.device).permute(2, 0, 1)
        # Relative offset of each HR pixel from its nearest LR pixel centre,
        # rescaled to LR-pixel units. Important: mode='nearest' gives
        # inconsistent results, so use 'nearest-exact'.
        rel_grid = hr_coord - F.interpolate(x_coord.unsqueeze(0), size=hr_size,
                                            mode='nearest-exact')
        rel_grid[:, 0, :, :] *= H
        rel_grid[:, 1, :, :] *= W
        return (rel_grid.contiguous().detach(),
                x_coord.contiguous().detach(),
                hr_coord.contiguous().detach())

    def forward(self, x, out_size):
        B, C, H_lr, W_lr = x.shape
        output_feats = self.forward_backbone(x)  # list of feature maps
        if hasattr(self, 'neck'):
            global_content, func_map = self.neck(output_feats)
        else:
            global_content = None
            func_map = output_feats[0]

        rel_coord, lr_coord, hr_coord = self.get_coordinate_map(x, out_size)

        scale_ratio_map = None
        if self.encode_scale_ratio:
            h_ratio = x.shape[2] / out_size[0]
            w_ratio = x.shape[3] / out_size[1]
            scale_ratio_map = torch.tensor([h_ratio, w_ratio]).view(1, -1, 1, 1) \
                .expand(B, -1, *out_size).to(x.device)

        # Iterative refinement: each pass feeds its RGB prediction back in.
        pred_rgb_value = None
        return_pred_rgb_value = []
        for n_time in range(self.n_forward_times):
            pred_rgb_value, returned_featmaps = self.forward_step(
                x, func_map, global_content, rel_coord, lr_coord, hr_coord,
                scale_ratio_map, pred_rgb_value)
            return_pred_rgb_value.append(pred_rgb_value)

        if self.return_featmap:
            return return_pred_rgb_value, returned_featmaps
        return return_pred_rgb_value
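

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative, not part of the original module).
# It assumes the repo's LIIF-style registry, i.e. models.make(spec, args=...)
# instantiates the class registered under spec['name'] with spec['args']
# updated by the extra args, and that to_coordinates(size, return_map=True)
# returns an (H, W, 2) coordinate map, as the permute(2, 0, 1) above implies.
# The registered names and toy modules below are hypothetical stand-ins,
# not real configs from this repo.
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    @register('toy-encoder')  # hypothetical name, for this sketch only
    class ToyEncoder(nn.Module):
        def __init__(self, out_dim=8):
            super().__init__()
            self.out_dim = out_dim  # FUNSR reads encoder.out_dim
            self.conv = nn.Conv2d(3, out_dim, 3, padding=1)

        def forward(self, x):
            return self.conv(x)

    @register('toy-decoder')  # hypothetical name, for this sketch only
    class ToyDecoder(nn.Module):
        # FUNSR supplies modulation_dim/in_dim via models.make(..., args=...)
        # and calls the decoder as decoder(coord_input, func_map).
        def __init__(self, modulation_dim, in_dim):
            super().__init__()
            self.head = nn.Conv2d(modulation_dim + in_dim, 3, 1)

        def forward(self, coord_input, func_map):
            return self.head(torch.cat([coord_input, func_map], dim=1))

    model = FUNSR(
        encoder_spec={'name': 'toy-encoder', 'args': {'out_dim': 8}},
        decoder={'name': 'toy-decoder', 'args': {}})
    lr = torch.randn(2, 3, 48, 48)
    preds = model(lr, out_size=(96, 96))  # list with n_forward_times entries
    print(preds[-1].shape)                # expected: torch.Size([2, 3, 96, 96])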