|
|
|
"""Contains the implementation of generator described in SGBEV3D.""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
from models.utils.official_stylegan2_model_helper import Generator as StyleGAN2Backbone |
|
from models.utils.official_stylegan3_model_helper import Generator as StyleGAN3Backbone |
|
from models.utils.unet import Generator as StyleGAN4Backbone |
|
from models.utils.official_stylegan2_model_helper import FullyConnectedLayer |
|
from models.utils.eg3d_superres import SuperresolutionHybrid2X |
|
from models.utils.eg3d_superres import SuperresolutionHybrid4X |
|
from models.utils.eg3d_superres import SuperresolutionHybrid4X_conststyle |
|
from models.utils.eg3d_superres import SuperresolutionHybrid8XDC |
|
from models.rendering.renderer import Renderer |
|
from models.rendering.feature_extractor import FeatureExtractor |
|
|
|
from models.utils.spade import SPADEGenerator |
|
|
|
class SGBEV3DGenerator(nn.Module):
    """Generator that renders images from a BEV feature plane (SGBEV3D).

    A StyleGAN-style backbone turns latents and a semantic map (``seg``) into a
    32-channel feature plane. ``self.renderer`` queries that plane through
    ``self.feature_extractor`` and decodes point features with ``self.fc_head``.
    The low-resolution render is then upsampled either by a learned
    super-resolution module or by bilinear interpolation.

    The constructor arguments ``ngf``, ``bev_grid_size``, ``aspect_ratio``,
    ``num_upsampling_layers``, ``not_use_vae``, ``norm_G`` and ``dim_seq`` are
    accepted for configuration compatibility but are not used by this class.
    """

    def __init__(
        self,
        z_dim,
        c_dim,
        w_dim,
        semantic_nc,
        ngf,
        bev_grid_size,
        aspect_ratio,
        num_upsampling_layers,
        not_use_vae,
        norm_G,
        interpolate_sr,
        segmask=False,
        dim_seq='16,8,4,2,1',
        xyz_pe=False,
        reverse_xy=True,
        hidden_dim=64,
        additional_layer_num=0,
        block_num=5,
        layer_num=2,
        ff_input=False,
        ref_mode='bev_plane_clevr_256',
        sel_type=None,
        backbone_ver=2,
        img_resolution=256,
        bev_resolution=256,
        sr_num_fp16_res=0,
        mapping_kwargs=None,
        rendering_kwargs=None,
        sr_kwargs=None,
        **synthesis_kwargs
    ):
        """Build the backbone, renderer, decoder head and super-resolution module.

        Args:
            z_dim: Latent code dimensionality.
            c_dim: Conditioning-label dimensionality passed to the backbone.
            w_dim: Intermediate latent dimensionality.
            semantic_nc: Channel count of the semantic map (backbone ``label_nc``).
            interpolate_sr: If true, upsample the raw render bilinearly instead of
                using the learned super-resolution module.
            segmask: If true, multiply the backbone planes by channel 0 of ``seg``.
            xyz_pe: Forwarded to ``FeatureExtractor``; also selects a 128- (True)
                or 64- (False) dimensional decoder input.
            hidden_dim: Hidden width of the ``OSGDecoder`` head.
            additional_layer_num: Extra hidden layers in the ``OSGDecoder`` head.
            backbone_ver: Backbone family: 2, 3 or 4.
            img_resolution: Final image resolution: 128, 256 or 512.
            bev_resolution: Resolution of the backbone's output plane.
            mapping_kwargs: Extra kwargs for the backbone mapping network.
            rendering_kwargs: Rendering options. Must contain ``'sr_antialias'``.
                ``synthesis`` also reads ``'superresolution_noise_mode'`` when
                the learned SR module is used. The dict is stored by reference.
            sr_kwargs: Overrides merged into the SR module's constructor kwargs.
            **synthesis_kwargs: Forwarded to the backbone constructor.

        Raises:
            ValueError: If ``backbone_ver`` is not supported.
            TypeError: If ``img_resolution`` is not supported.
        """
        super().__init__()

        # ``None`` defaults avoid one mutable dict being shared by every instance.
        mapping_kwargs = {} if mapping_kwargs is None else mapping_kwargs
        rendering_kwargs = {} if rendering_kwargs is None else rendering_kwargs
        sr_kwargs = {} if sr_kwargs is None else sr_kwargs

        self.z_dim = z_dim
        self.interpolate_sr = interpolate_sr
        self.segmask = segmask
        self.img_resolution = img_resolution

        self.renderer = Renderer()
        self.feature_extractor = FeatureExtractor(
            ref_mode=ref_mode, xyz_pe=xyz_pe, reverse_xy=reverse_xy)

        # Explicit table instead of a ``globals()`` lookup on a formatted name.
        backbones = {
            '2': StyleGAN2Backbone,
            '3': StyleGAN3Backbone,
            '4': StyleGAN4Backbone,
        }
        backbone_cls = backbones.get(str(backbone_ver))
        if backbone_cls is None:
            raise ValueError(f'Unsupported backbone version: {backbone_ver}!')
        self.backbone = backbone_cls(
            z_dim,
            c_dim,
            w_dim,
            img_resolution=bev_resolution,
            img_channels=32,
            label_nc=semantic_nc,
            use_sel=True,
            sel_type=sel_type,
            mapping_kwargs=mapping_kwargs,
            ff_input=ff_input,
            block_num=block_num,
            layer_num=layer_num,
            **synthesis_kwargs)

        self.post_module = None

        sr_kwargs_total = dict(
            channels=32,
            img_resolution=img_resolution,
            sr_num_fp16_res=sr_num_fp16_res,
            sr_antialias=rendering_kwargs['sr_antialias'],
        )
        sr_kwargs_total.update(sr_kwargs)
        if img_resolution == 128:
            self.post_neural_renderer = SuperresolutionHybrid2X(**sr_kwargs_total)
        elif img_resolution == 256:
            self.post_neural_renderer = SuperresolutionHybrid4X_conststyle(
                **sr_kwargs_total)
        elif img_resolution == 512:
            self.post_neural_renderer = SuperresolutionHybrid8XDC(**sr_kwargs_total)
        else:
            raise TypeError(f'Unsupported image resolution: {img_resolution}!')

        # Decoder predicts 1 density channel + 32 feature channels per point.
        self.fc_head = OSGDecoder(
            128 if xyz_pe else 64,
            {
                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': 32,
            },
            hidden_dim=hidden_dim,
            additional_layer_num=additional_layer_num,
        )

        self.neural_rendering_resolution = rendering_kwargs.get('resolution', 64)
        self.rendering_kwargs = rendering_kwargs

    def mapping(self,
                z,
                c,
                truncation_psi=1,
                truncation_cutoff=None,
                update_emas=False):
        """Map ``z`` and the scaled label ``c`` to backbone latents.

        ``c`` is zeroed when ``rendering_kwargs['c_gen_conditioning_zero']`` is
        true (missing means false). It is then multiplied by
        ``rendering_kwargs['c_scale']``, which defaults to 0.
        """
        if self.rendering_kwargs.get('c_gen_conditioning_zero', False):
            c = torch.zeros_like(c)
        c_scale = self.rendering_kwargs.get('c_scale', 0)
        return self.backbone.mapping(z,
                                     c * c_scale,
                                     truncation_psi=truncation_psi,
                                     truncation_cutoff=truncation_cutoff,
                                     update_emas=update_emas)

    def _synthesize_planes(self, wp, seg, update_emas=False, **synthesis_kwargs):
        """Run the backbone on ``seg`` and apply the optional segmentation mask."""
        xy_planes = self.backbone.synthesis(
            wp, heatmap=seg, update_emas=update_emas, **synthesis_kwargs)
        if self.segmask:
            # Keep a singleton channel axis so the mask broadcasts over features.
            xy_planes = xy_planes * seg[:, 0:1]
        return xy_planes

    def synthesis(self,
                  wp,
                  c,
                  seg,
                  neural_rendering_resolution=None,
                  update_emas=False,
                  **synthesis_kwargs):
        """Render an image from latents ``wp`` and semantic map ``seg``.

        ``c[:, :16]`` is read as a flattened 4x4 camera-to-world matrix. It is
        replaced by ``None`` when ``rendering_kwargs['random_pose']`` is set.
        Passing ``neural_rendering_resolution`` also updates the stored value.

        Returns:
            Dict with ``'image'`` (upsampled RGB), ``'image_raw'`` (low-res RGB),
            ``'image_depth'`` (``(N, 1, H, W)``), ``'plane'`` (backbone planes,
            masked when ``segmask``), and the renderer's ``'points'`` and
            ``'sigmas'``.
        """
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        if self.rendering_kwargs.get('random_pose', False):
            cam2world_matrix = None

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        xy_planes = self._synthesize_planes(
            wp, seg, update_emas=update_emas, **synthesis_kwargs)

        rendering_result = self.renderer(
            wp=wp,
            feature_extractor=self.feature_extractor,
            rendering_options=self.rendering_kwargs,
            cam2world_matrix=cam2world_matrix,
            position_encoder=None,
            ref_representation=xy_planes,
            post_module=self.post_module,
            fc_head=self.fc_head)

        feature_samples = rendering_result['composite_rgb']
        depth_samples = rendering_result['composite_depth']

        # Rays are laid out as a square H x W grid; fold them back into images.
        N = wp.shape[0]
        H = W = neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        rgb_image = feature_image[:, :3]
        if self.interpolate_sr:
            # getattr keeps instances pickled before ``img_resolution`` existed working.
            out_res = getattr(self, 'img_resolution', 256)
            sr_image = torch.nn.functional.interpolate(
                rgb_image, size=(out_res, out_res),
                mode='bilinear', align_corners=False)
        else:
            # The SR module gets its own configured noise mode, so drop the caller's.
            sr_extra = {k: v for k, v in synthesis_kwargs.items() if k != 'noise_mode'}
            sr_image = self.post_neural_renderer(
                rgb_image,
                feature_image,
                noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
                **sr_extra)

        return {
            'image': sr_image,
            'image_raw': rgb_image,
            'image_depth': depth_image,
            'plane': xy_planes,
            'points': rendering_result['points'],
            'sigmas': rendering_result['sigmas'],
        }

    def sample(self,
               coordinates,
               directions,
               z,
               c,
               seg,
               truncation_psi=1,
               truncation_cutoff=None,
               update_emas=False,
               **synthesis_kwargs):
        """Map ``z`` to latents, then query the field at ``coordinates``.

        Returns whatever ``self.renderer.get_sigma_rgb`` returns.
        """
        wp = self.mapping(z, c, truncation_psi=truncation_psi,
                          truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)
        return self.sample_mixed(coordinates, directions, wp, c, seg,
                                 update_emas=update_emas, **synthesis_kwargs)

    def sample_mixed(self,
                     coordinates,
                     directions,
                     wp, c, seg,
                     truncation_psi=1,
                     truncation_cutoff=None,
                     update_emas=False,
                     **synthesis_kwargs):
        """Query the field at ``coordinates`` using precomputed latents ``wp``.

        ``truncation_psi`` and ``truncation_cutoff`` are accepted for interface
        symmetry with ``sample`` but unused. Planes are built exactly as in
        ``synthesis``, including the optional segmentation mask.
        """
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        xy_planes = self._synthesize_planes(
            wp, seg, update_emas=update_emas, **synthesis_kwargs)
        return self.renderer.get_sigma_rgb(
            wp=wp,
            points=coordinates,
            feature_extractor=self.feature_extractor,
            fc_head=self.fc_head,
            rendering_options=self.rendering_kwargs,
            ref_representation=xy_planes,
            post_module=self.post_module,
            ray_dirs=directions,
            cam_matrix=cam2world_matrix)

    def forward(self,
                z,
                c,
                seg,
                c_swapped=None,
                style_mixing_prob=0,
                truncation_psi=1,
                truncation_cutoff=None,
                neural_rendering_resolution=None,
                update_emas=False,
                sample_mixed=False,
                coordinates=None,
                **synthesis_kwargs):
        """Map latents and either render an image or sample densities.

        The mapping network is conditioned on ``c_swapped`` when given,
        otherwise on ``c``. Rendering always uses ``c``. ``style_mixing_prob``
        is accepted but unused.

        Returns:
            ``{'wp', 'gen_output'}`` when ``sample_mixed`` is false, otherwise
            ``{'wp', 'sample_sigma'}``.

        Raises:
            ValueError: If ``sample_mixed`` is true and ``coordinates`` is None.
        """
        c_wp = (c if c_swapped is None else c_swapped).clone()
        wp = self.mapping(c_wp if False else z, c_wp,
                          truncation_psi=truncation_psi,
                          truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)

        if not sample_mixed:
            gen_output = self.synthesis(
                wp,
                c,
                seg,
                update_emas=update_emas,
                neural_rendering_resolution=neural_rendering_resolution,
                **synthesis_kwargs)
            # NOTE(review): key 'wp' holds the input ``z``, not the mapped
            # latents; kept as-is for existing callers -- confirm intent.
            return {
                'wp': z,
                'gen_output': gen_output,
            }

        if coordinates is None:
            raise ValueError('`coordinates` is required when `sample_mixed` is True.')
        # Directions are random noise here; synthesis_kwargs are not forwarded.
        sample_sigma = self.sample_mixed(coordinates,
                                         torch.randn_like(coordinates),
                                         wp, c, seg,
                                         update_emas=False)['sigma']
        return {
            'wp': z,
            'sample_sigma': sample_sigma,
        }
|
|
|
|
|
class OSGDecoder(nn.Module):
    """Defines the fully-connected decoder head used in EG3D.

    Maps per-point features to one density channel (``sigma``) plus
    ``options['decoder_output_dim']`` color/feature channels (``rgb``).
    """

    def __init__(self, n_features, options, hidden_dim=64, additional_layer_num=0):
        """Build the MLP: FC + Softplus, optional extra hidden layers, then an output FC.

        Args:
            n_features: Width of the input point features.
            options: Dict with ``'decoder_lr_mul'`` (learning-rate multiplier
                applied to every layer) and ``'decoder_output_dim'`` (number of
                non-density output channels).
            hidden_dim: Width of the hidden layers.
            additional_layer_num: Number of hidden layers beyond the first one.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        lr_mul = options['decoder_lr_mul']

        layers = [
            FullyConnectedLayer(n_features, hidden_dim, lr_multiplier=lr_mul),
            nn.Softplus(),
        ]
        for _ in range(additional_layer_num):
            layers.append(FullyConnectedLayer(hidden_dim, hidden_dim, lr_multiplier=lr_mul))
            layers.append(nn.Softplus())
        # One extra output channel holds the density (sigma).
        layers.append(FullyConnectedLayer(
            hidden_dim, 1 + options['decoder_output_dim'], lr_multiplier=lr_mul))
        self.net = nn.Sequential(*layers)

    def forward(self, point_features, wp=None, dirs=None):
        """Decode point features into density and color.

        Args:
            point_features: Tensor whose last dim is ``n_features``. Every dim
                between the first (batch) and the last is flattened into a points
                axis, e.g. ``(N, R, K, C)`` gives ``N x (R*K)`` points.
            wp: Unused; kept for interface compatibility.
            dirs: Unused; kept for interface compatibility.

        Returns:
            Dict with ``'rgb'`` of shape ``(N, P, decoder_output_dim)``, squashed
            into ``(-0.001, 1.001)``, and raw ``'sigma'`` of shape ``(N, P, 1)``.
        """
        batch = point_features.shape[0]
        x = self.net(point_features.reshape(-1, point_features.shape[-1]))
        x = x.reshape(batch, -1, x.shape[-1])

        # Slightly widened sigmoid (as in EG3D) so exactly 0 and 1 are reachable.
        rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001
        sigma = x[..., 0:1]
        return {'rgb': rgb, 'sigma': sigma}
|
|