|
|
|
"""Contains the implementation of generator described in EG3D.""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
from models.utils.official_stylegan2_model_helper import Generator as StyleGAN2Backbone |
|
from models.utils.official_stylegan2_model_helper import FullyConnectedLayer |
|
from models.utils.eg3d_superres import SuperresolutionHybrid2X |
|
from models.utils.eg3d_superres import SuperresolutionHybrid4X |
|
from models.utils.eg3d_superres import SuperresolutionHybrid8XDC |
|
from models.rendering.renderer import Renderer |
|
from models.rendering.feature_extractor import FeatureExtractor |
|
|
|
class EG3DGenerator(nn.Module):
    """Tri-plane generator following EG3D.

    A StyleGAN2 backbone synthesizes a ``3 * 32``-channel feature map that is
    split into three 32-channel planes. ``self.renderer`` renders these planes
    (decoded by ``self.fc_head``) at ``neural_rendering_resolution``. A
    super-resolution module built for ``img_resolution`` then produces the
    final image.
    """

    def __init__(
        self,
        z_dim,
        c_dim,
        w_dim,
        img_resolution,
        img_channels,
        sr_num_fp16_res=0,
        mapping_kwargs=None,
        rendering_kwargs=None,
        sr_kwargs=None,
        **synthesis_kwargs,
    ):
        """Build the backbone, renderer, decoder head and super-resolution.

        Args:
            z_dim: Dimensionality of the latent code ``z``.
            c_dim: Dimensionality of the conditioning vector ``c``.
            w_dim: Dimensionality of the intermediate latent ``w``.
            img_resolution: Final output resolution; must be 128, 256 or 512.
            img_channels: Number of output image channels (only stored).
            sr_num_fp16_res: Forwarded to the super-resolution module.
            mapping_kwargs: Extra kwargs for the backbone mapping network.
            rendering_kwargs: Rendering options dict. ``'sr_antialias'`` is
                required here. ``'c_gen_conditioning_zero'`` and
                ``'superresolution_noise_mode'`` are required later by
                :meth:`mapping` / :meth:`synthesis`.
            sr_kwargs: Overrides merged into the super-resolution kwargs.
            **synthesis_kwargs: Forwarded to the StyleGAN2 backbone.

        Raises:
            TypeError: If ``img_resolution`` is not 128, 256 or 512.
        """
        super().__init__()
        # ``None`` defaults avoid one mutable dict being shared by every
        # instance (``rendering_kwargs`` is stored on ``self`` below).
        mapping_kwargs = {} if mapping_kwargs is None else mapping_kwargs
        rendering_kwargs = {} if rendering_kwargs is None else rendering_kwargs
        sr_kwargs = {} if sr_kwargs is None else sr_kwargs

        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        self.renderer = Renderer()
        self.feature_extractor = FeatureExtractor(ref_mode='tri_plane')

        # The backbone emits 3 * 32 channels at 256x256. ``_build_tri_planes``
        # reshapes them into three 32-channel planes.
        self.backbone = StyleGAN2Backbone(z_dim,
                                          c_dim,
                                          w_dim,
                                          img_resolution=256,
                                          img_channels=32 * 3,
                                          mapping_kwargs=mapping_kwargs,
                                          **synthesis_kwargs)

        # Passed to the renderer as ``post_module``; unused (None) here.
        self.post_module = None

        sr_kwargs_total = dict(
            channels=32,
            img_resolution=img_resolution,
            sr_num_fp16_res=sr_num_fp16_res,
            sr_antialias=rendering_kwargs['sr_antialias'],
        )
        sr_kwargs_total.update(sr_kwargs)
        sr_classes = {
            128: SuperresolutionHybrid2X,
            256: SuperresolutionHybrid4X,
            512: SuperresolutionHybrid8XDC,
        }
        if img_resolution not in sr_classes:
            raise TypeError(f'Unsupported image resolution: {img_resolution}!')
        self.post_neural_renderer = sr_classes[img_resolution](
            **sr_kwargs_total)

        # Decodes 32-channel plane features into 1 sigma channel plus 32
        # feature channels (see OSGDecoder).
        self.fc_head = OSGDecoder(
            32, {
                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': 32,
            })

        self.neural_rendering_resolution = rendering_kwargs.get(
            'resolution', 64)
        self.rendering_kwargs = rendering_kwargs

    def mapping(self,
                z,
                c,
                truncation_psi=1,
                truncation_cutoff=None,
                update_emas=False):
        """Map latent ``z`` and conditioning ``c`` to ``wp`` via the backbone.

        ``c`` is zeroed when ``rendering_kwargs['c_gen_conditioning_zero']``
        is truthy. It is then scaled by ``rendering_kwargs['c_scale']``, which
        defaults to 0, so the mapping network receives an all-zero
        conditioning vector unless a scale is configured.
        """
        if self.rendering_kwargs['c_gen_conditioning_zero']:
            c = torch.zeros_like(c)
        c_scale = self.rendering_kwargs.get('c_scale', 0)
        return self.backbone.mapping(z,
                                     c * c_scale,
                                     truncation_psi=truncation_psi,
                                     truncation_cutoff=truncation_cutoff,
                                     update_emas=update_emas)

    def _build_tri_planes(self, wp, update_emas=False, **synthesis_kwargs):
        """Run backbone synthesis and split its channels into three planes.

        Returns a tensor of shape ``(N, 3, C, H, W)``. The backbone's
        ``3 * C`` output channels are divided evenly across the planes.
        """
        planes = self.backbone.synthesis(wp,
                                         update_emas=update_emas,
                                         **synthesis_kwargs)
        return planes.view(len(planes), 3, -1, planes.shape[-2],
                           planes.shape[-1])

    def synthesis(self,
                  wp,
                  c,
                  neural_rendering_resolution=None,
                  update_emas=False,
                  **synthesis_kwargs):
        """Render an image from ``wp`` using the camera encoded in ``c``.

        Args:
            wp: Per-layer latent codes, e.g. from :meth:`mapping`.
            c: Conditioning vector whose first 16 entries are reshaped into
                a 4x4 cam2world matrix.
            neural_rendering_resolution: Raw rendering resolution. When
                given, it is also stored on the module for later calls.
            update_emas: Forwarded to the backbone synthesis.
            **synthesis_kwargs: Forwarded to the backbone synthesis. All
                entries except ``noise_mode`` are also forwarded to the
                super-resolution module.

        Returns:
            dict with ``'image'`` (super-resolved output), ``'image_raw'``
            (first 3 rendered feature channels) and ``'image_depth'``
            (rendered depth, shape ``(N, 1, H, W)``).
        """
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        if self.rendering_kwargs.get('random_pose', False):
            # No camera is handed to the renderer. Presumably it samples a
            # pose itself -- confirm in Renderer.
            cam2world_matrix = None

        if neural_rendering_resolution is not None:
            self.neural_rendering_resolution = neural_rendering_resolution

        tri_planes = self._build_tri_planes(wp,
                                            update_emas=update_emas,
                                            **synthesis_kwargs)

        rendering_result = self.renderer(
            wp=wp,
            feature_extractor=self.feature_extractor,
            rendering_options=self.rendering_kwargs,
            cam2world_matrix=cam2world_matrix,
            position_encoder=None,
            ref_representation=tri_planes,
            post_module=self.post_module,
            fc_head=self.fc_head)
        feature_samples = rendering_result['composite_rgb']
        depth_samples = rendering_result['composite_depth']

        # Samples arrive as (N, H * W, C); move channels first and unflatten.
        # NOTE(review): ``neural_rendering_resolution`` is never passed to the
        # renderer. This reshape assumes the renderer produced H * W samples
        # at that resolution (presumably via rendering_options) -- confirm.
        N = wp.shape[0]
        H = W = self.neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        rgb_image = feature_image[:, :3]
        # The SR module's noise mode comes from the rendering options, so a
        # caller-supplied ``noise_mode`` is dropped to avoid a duplicate kwarg.
        sr_extra_kwargs = {
            k: v for k, v in synthesis_kwargs.items() if k != 'noise_mode'
        }
        sr_image = self.post_neural_renderer(
            rgb_image,
            feature_image,
            wp,
            noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
            **sr_extra_kwargs)

        return {
            'image': sr_image,
            'image_raw': rgb_image,
            'image_depth': depth_image,
        }

    def sample(self,
               coordinates,
               directions,
               z,
               c,
               truncation_psi=1,
               truncation_cutoff=None,
               update_emas=False,
               **synthesis_kwargs):
        """Query the radiance field at ``coordinates`` for latents ``z``.

        Equivalent to :meth:`mapping` followed by :meth:`sample_mixed`.
        Returns whatever ``renderer.get_sigma_rgb`` returns, which includes
        a ``'sigma'`` entry.
        """
        wp = self.mapping(z,
                          c,
                          truncation_psi=truncation_psi,
                          truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)
        return self.sample_mixed(coordinates,
                                 directions,
                                 wp,
                                 update_emas=update_emas,
                                 **synthesis_kwargs)

    def sample_mixed(self,
                     coordinates,
                     directions,
                     wp,
                     truncation_psi=1,
                     truncation_cutoff=None,
                     update_emas=False,
                     **synthesis_kwargs):
        """Query the radiance field at ``coordinates`` for precomputed ``wp``.

        ``truncation_psi`` and ``truncation_cutoff`` are accepted for
        interface compatibility and are ignored. Returns whatever
        ``renderer.get_sigma_rgb`` returns, which includes ``'sigma'``.
        """
        tri_planes = self._build_tri_planes(wp,
                                            update_emas=update_emas,
                                            **synthesis_kwargs)
        return self.renderer.get_sigma_rgb(
            wp=wp,
            points=coordinates,
            feature_extractor=self.feature_extractor,
            fc_head=self.fc_head,
            rendering_options=self.rendering_kwargs,
            ref_representation=tri_planes,
            post_module=self.post_module,
            ray_dirs=directions)

    def forward(self,
                z,
                c,
                c_swapped=None,
                style_mixing_prob=0,
                truncation_psi=1,
                truncation_cutoff=None,
                neural_rendering_resolution=None,
                update_emas=False,
                sample_mixed=False,
                coordinates=None,
                **synthesis_kwargs):
        """Map latents to ``wp``, then render images or sample density.

        Args:
            z: Latent codes.
            c: Conditioning vector. :meth:`synthesis` uses its first 16
                entries as the camera pose.
            c_swapped: Optional conditioning fed to the mapping network in
                place of ``c``. ``c`` is still used for rendering and for the
                style-mixing latents.
            style_mixing_prob: Probability of replacing the trailing ``wp``
                layers with those of a freshly sampled latent.
            truncation_psi, truncation_cutoff: Forwarded to :meth:`mapping`.
            neural_rendering_resolution: Forwarded to :meth:`synthesis`.
            update_emas: Forwarded to the primary mapping call and to
                :meth:`synthesis`.
            sample_mixed: If True, return density at ``coordinates`` instead
                of rendering images.
            coordinates: Query points; required when ``sample_mixed`` is True.
            **synthesis_kwargs: Forwarded to :meth:`synthesis`.

        Returns:
            ``{'wp': ..., 'gen_output': ...}`` when rendering, otherwise
            ``{'wp': ..., 'sample_sigma': ...}``.

        Raises:
            ValueError: If ``sample_mixed`` is True and ``coordinates`` is None.
        """
        c_wp = c.clone() if c_swapped is None else c_swapped.clone()
        wp = self.mapping(z,
                          c_wp,
                          truncation_psi=truncation_psi,
                          truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)
        if style_mixing_prob > 0:
            num_ws = wp.shape[1]
            cutoff = torch.empty([], dtype=torch.int64,
                                 device=wp.device).random_(1, num_ws)
            # With probability 1 - style_mixing_prob the cutoff is num_ws,
            # which makes the assignment below a no-op.
            cutoff = torch.where(
                torch.rand([], device=wp.device) < style_mixing_prob,
                cutoff, torch.full_like(cutoff, num_ws))
            # Mixing latents are auxiliary samples. As in the StyleGAN3/EG3D
            # reference loss, they are kept out of the latent EMA update.
            wp[:, cutoff:] = self.mapping(torch.randn_like(z),
                                          c,
                                          update_emas=False)[:, cutoff:]

        if not sample_mixed:
            gen_output = self.synthesis(
                wp,
                c,
                update_emas=update_emas,
                neural_rendering_resolution=neural_rendering_resolution,
                **synthesis_kwargs)
            return {
                'wp': wp,
                'gen_output': gen_output,
            }

        if coordinates is None:
            raise ValueError(
                '`coordinates` must be provided when `sample_mixed=True`.')
        # Only 'sigma' is read from the result, so random directions are
        # passed as placeholders.
        sample_sigma = self.sample_mixed(coordinates,
                                         torch.randn_like(coordinates),
                                         wp,
                                         update_emas=False)['sigma']
        return {
            'wp': wp,
            'sample_sigma': sample_sigma,
        }
|
|
|
|
|
class OSGDecoder(nn.Module):
    """Fully-connected decoder head used by EG3D.

    A two-layer MLP (Softplus in between) turns per-point features into one
    raw ``sigma`` channel plus ``decoder_output_dim`` feature channels. The
    feature channels are squashed by a sigmoid stretched to
    ``[-0.001, 1.001]``.
    """

    def __init__(self, n_features, options):
        """Create the MLP.

        Args:
            n_features: Input feature width per point.
            options: Dict providing ``'decoder_lr_mul'`` (learning-rate
                multiplier for both layers) and ``'decoder_output_dim'``
                (number of non-sigma output channels).
        """
        super().__init__()
        self.hidden_dim = 64

        lr_mul = options['decoder_lr_mul']
        first = FullyConnectedLayer(n_features,
                                    self.hidden_dim,
                                    lr_multiplier=lr_mul)
        activation = nn.Softplus()
        second = FullyConnectedLayer(self.hidden_dim,
                                     1 + options['decoder_output_dim'],
                                     lr_multiplier=lr_mul)
        self.net = nn.Sequential(first, activation, second)

    def forward(self, point_features, wp=None, dirs=None):
        """Decode per-point features into ``rgb`` and ``sigma``.

        Args:
            point_features: Tensor whose dim 1 is averaged away, leaving
                ``(N, M, C)``. Presumably ``(N, n_planes, M, C)`` from the
                tri-plane sampler -- TODO confirm with the caller.
            wp: Unused; accepted for interface compatibility.
            dirs: Unused; accepted for interface compatibility.

        Returns:
            dict with ``'rgb'`` of shape ``(N, M, decoder_output_dim)`` and
            ``'sigma'`` of shape ``(N, M, 1)``.
        """
        fused = point_features.mean(1)
        batch, num_points, channels = fused.shape

        flat_out = self.net(fused.view(batch * num_points, channels))
        decoded = flat_out.view(batch, num_points, -1)

        sigma = decoded[..., :1]
        # Sigmoid widened slightly past [0, 1] so the extremes are reachable.
        margin = 0.001
        rgb = torch.sigmoid(decoded[..., 1:]) * (1 + 2 * margin) - margin
        return {'rgb': rgb, 'sigma': sigma}
|
|