BerfScene / models /eg3d_generator.py
3v324v23's picture
init
2f85de4
raw
history blame
12.4 kB
# python3.8
"""Contains the implementation of generator described in EG3D."""
import torch
import torch.nn as nn
from models.utils.official_stylegan2_model_helper import Generator as StyleGAN2Backbone
from models.utils.official_stylegan2_model_helper import FullyConnectedLayer
from models.utils.eg3d_superres import SuperresolutionHybrid2X
from models.utils.eg3d_superres import SuperresolutionHybrid4X
from models.utils.eg3d_superres import SuperresolutionHybrid8XDC
from models.rendering.renderer import Renderer
from models.rendering.feature_extractor import FeatureExtractor
class EG3DGenerator(nn.Module):
    """Generator described in EG3D.

    A StyleGAN2 backbone synthesizes a 256x256 feature map with 3 * 32
    channels, which is regrouped into three 32-channel planes. `Renderer`
    turns these planes into a low-resolution feature image plus depth map,
    and a super-resolution network upsamples the result to `img_resolution`
    (128, 256 or 512).
    """

    def __init__(
            self,
            z_dim,  # Input latent (Z) dimensionality.
            c_dim,  # Conditioning label (C) dimensionality.
            w_dim,  # Intermediate latent (W) dimensionality.
            img_resolution,  # Output resolution.
            img_channels,  # Number of output color channels.
            sr_num_fp16_res=0,  # Number of fp16 layers of SR Network.
            mapping_kwargs=None,  # Arguments for MappingNetwork.
            rendering_kwargs=None,  # Arguments for rendering.
            sr_kwargs=None,  # Arguments for SuperResolution Network.
            **synthesis_kwargs,  # Arguments for SynthesisNetwork.
    ):
        """Build the backbone, renderer, super-resolution network and head.

        Raises:
            KeyError: If `rendering_kwargs` has no `'sr_antialias'` entry.
            TypeError: If `img_resolution` is not 128, 256 or 512.
        """
        super().__init__()
        # `None` defaults instead of `{}`: a literal default dict is created
        # once and shared by every call, and `rendering_kwargs` is kept on
        # `self`, so a shared default could leak state between instances.
        mapping_kwargs = {} if mapping_kwargs is None else mapping_kwargs
        rendering_kwargs = {} if rendering_kwargs is None else rendering_kwargs
        sr_kwargs = {} if sr_kwargs is None else sr_kwargs

        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        # NOTE: submodules are registered in the same order as before
        # (renderer, feature_extractor, backbone, post_neural_renderer,
        # fc_head) so parameter / state_dict ordering is unchanged.
        # Set up the overall renderer.
        self.renderer = Renderer()
        # Set up the feature extractor.
        self.feature_extractor = FeatureExtractor(ref_mode='tri_plane')
        # Set up the reference representation generator. Its "image" is the
        # stacked tri-plane: 3 planes x 32 channels at 256x256.
        self.backbone = StyleGAN2Backbone(z_dim,
                                          c_dim,
                                          w_dim,
                                          img_resolution=256,
                                          img_channels=32 * 3,
                                          mapping_kwargs=mapping_kwargs,
                                          **synthesis_kwargs)
        # No post module is used in the feature extractor.
        self.post_module = None

        # Set up the post neural renderer: a super-resolution network
        # selected by the requested output resolution.
        self.post_neural_renderer = None
        sr_kwargs_total = dict(
            channels=32,
            img_resolution=img_resolution,
            sr_num_fp16_res=sr_num_fp16_res,
            sr_antialias=rendering_kwargs['sr_antialias'],
        )
        sr_kwargs_total.update(**sr_kwargs)
        if img_resolution == 128:
            self.post_neural_renderer = SuperresolutionHybrid2X(
                **sr_kwargs_total)
        elif img_resolution == 256:
            self.post_neural_renderer = SuperresolutionHybrid4X(
                **sr_kwargs_total)
        elif img_resolution == 512:
            self.post_neural_renderer = SuperresolutionHybrid8XDC(
                **sr_kwargs_total)
        else:
            raise TypeError(f'Unsupported image resolution: {img_resolution}!')

        # MLP head decoding 32-channel plane features into sigma plus 32
        # feature channels.
        self.fc_head = OSGDecoder(
            32, {
                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': 32
            })

        # Side length of the raw (pre-super-resolution) rendered image.
        self.neural_rendering_resolution = rendering_kwargs.get(
            'resolution', 64)
        # Kept by reference (not copied) so callers can still adjust it.
        self.rendering_kwargs = rendering_kwargs

    def mapping(self,
                z,
                c,
                truncation_psi=1,
                truncation_cutoff=None,
                update_emas=False):
        """Map latent `z` and camera label `c` to per-layer latent codes.

        If `rendering_kwargs['c_gen_conditioning_zero']` is truthy, `c` is
        replaced by zeros. `c` is then multiplied by
        `rendering_kwargs.get('c_scale', 0)`. With the default scale of 0,
        the conditioning fed to the backbone is zero.
        """
        if self.rendering_kwargs['c_gen_conditioning_zero']:
            c = torch.zeros_like(c)
        c_scale = self.rendering_kwargs.get('c_scale', 0)
        return self.backbone.mapping(z,
                                     c * c_scale,
                                     truncation_psi=truncation_psi,
                                     truncation_cutoff=truncation_cutoff,
                                     update_emas=update_emas)

    def _synthesize_tri_planes(self, wp, update_emas=False,
                               **synthesis_kwargs):
        """Run the backbone synthesis network and split it into 3 planes.

        The backbone emits `3 * C` channels (C = 32 as configured in
        `__init__`). They are regrouped to `[N, 3, C, H, W]`.
        """
        tri_planes = self.backbone.synthesis(wp,
                                             update_emas=update_emas,
                                             **synthesis_kwargs)
        return tri_planes.view(len(tri_planes), 3, -1,
                               tri_planes.shape[-2],
                               tri_planes.shape[-1])

    def _query_tri_planes(self, coordinates, directions, wp, update_emas,
                          **synthesis_kwargs):
        """Synthesize tri-planes from `wp` and evaluate them at `coordinates`.

        Returns whatever `Renderer.get_sigma_rgb` returns. Callers in this
        class read its `'sigma'` entry.
        """
        tri_planes = self._synthesize_tri_planes(wp,
                                                 update_emas=update_emas,
                                                 **synthesis_kwargs)
        return self.renderer.get_sigma_rgb(
            wp=wp,
            points=coordinates,
            feature_extractor=self.feature_extractor,
            fc_head=self.fc_head,
            rendering_options=self.rendering_kwargs,
            ref_representation=tri_planes,
            post_module=self.post_module,
            ray_dirs=directions)

    def synthesis(self,
                  wp,
                  c,
                  neural_rendering_resolution=None,
                  update_emas=False,
                  **synthesis_kwargs):
        """Render images from latent codes `wp` under camera label `c`.

        Args:
            wp: Per-layer latent codes, e.g. from `mapping()`.
            c: Camera label. `c[:, :16]` is viewed as a flattened 4x4
                camera-to-world matrix.
            neural_rendering_resolution: Side length of the raw rendered
                image. When None, `self.neural_rendering_resolution` is used.
                When given, it also overwrites that attribute.
            update_emas: Forwarded to the backbone synthesis network.
            **synthesis_kwargs: Forwarded to the backbone synthesis network
                and, except `noise_mode`, to the super-resolution network.

        Returns:
            A dict with the following entries:
            - `'image'`: the super-resolved output.
            - `'image_raw'`: the first 3 channels of the raw rendered
              feature image.
            - `'image_depth'`: the `[N, 1, H, W]` depth map.
        """
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        if self.rendering_kwargs.get('random_pose', False):
            # NOTE(review): presumably the renderer samples its own pose when
            # given None; confirm against `Renderer`.
            cam2world_matrix = None
        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        tri_planes = self._synthesize_tri_planes(wp,
                                                 update_emas=update_emas,
                                                 **synthesis_kwargs)
        rendering_result = self.renderer(
            wp=wp,
            feature_extractor=self.feature_extractor,
            rendering_options=self.rendering_kwargs,
            cam2world_matrix=cam2world_matrix,
            position_encoder=None,
            ref_representation=tri_planes,
            post_module=self.post_module,
            fc_head=self.fc_head)
        feature_samples = rendering_result['composite_rgb']
        depth_samples = rendering_result['composite_depth']

        # Reshape to keep consistent with 'raw' neural-rendered image.
        # NOTE(review): the renderer only receives `self.rendering_kwargs`,
        # not `neural_rendering_resolution`. If it derives the ray count from
        # `rendering_kwargs['resolution']`, passing a different resolution
        # here would make these reshapes fail. Confirm against `Renderer`.
        N = wp.shape[0]
        H = W = self.neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        # Run the post neural renderer (a super-resolution network) to get
        # the final image. Its noise mode comes from the rendering config,
        # so a caller-supplied `noise_mode` is not forwarded to it.
        rgb_image = feature_image[:, :3]
        sr_kwargs = {k: v for k, v in synthesis_kwargs.items()
                     if k != 'noise_mode'}
        sr_image = self.post_neural_renderer(
            rgb_image,
            feature_image,
            wp,
            noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
            **sr_kwargs)
        return {
            'image': sr_image,
            'image_raw': rgb_image,
            'image_depth': depth_image
        }

    def sample(self,
               coordinates,
               directions,
               z,
               c,
               truncation_psi=1,
               truncation_cutoff=None,
               update_emas=False,
               **synthesis_kwargs):
        """Compute RGB features and density for arbitrary 3D coordinates.

        Mostly used for extracting shapes. `z` and `c` are first mapped to
        `wp` via `mapping()`.
        """
        wp = self.mapping(z,
                          c,
                          truncation_psi=truncation_psi,
                          truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)
        return self._query_tri_planes(coordinates, directions, wp,
                                      update_emas=update_emas,
                                      **synthesis_kwargs)

    def sample_mixed(self,
                     coordinates,
                     directions,
                     wp,
                     truncation_psi=1,
                     truncation_cutoff=None,
                     update_emas=False,
                     **synthesis_kwargs):
        """Same as `sample()`, but takes latent codes `wp` instead of `z`.

        `truncation_psi` and `truncation_cutoff` are accepted for signature
        parity with `sample()` but are unused, since no mapping is done.
        """
        return self._query_tri_planes(coordinates, directions, wp,
                                      update_emas=update_emas,
                                      **synthesis_kwargs)

    def forward(self,
                z,
                c,
                c_swapped=None,  # `c_swapped` is swapped pose conditioning.
                style_mixing_prob=0,
                truncation_psi=1,
                truncation_cutoff=None,
                neural_rendering_resolution=None,
                update_emas=False,
                sample_mixed=False,
                coordinates=None,
                **synthesis_kwargs):
        """Map `z` to latent codes and either render images or sample sigma.

        Args:
            z: Input latents.
            c: Camera label used for rendering (and for style mixing).
            c_swapped: If given, used instead of `c` as the mapping-network
                conditioning.
            style_mixing_prob: Probability of mixing in latent codes from a
                second random `z` from a random layer onward.
            sample_mixed: If True, skip rendering and only query density at
                `coordinates`. This is used for density regularization during
                training.
            coordinates: Query points. Required when `sample_mixed` is True.

        Returns:
            `{'wp': ..., 'gen_output': synthesis(...)}` by default.
            `{'wp': ..., 'sample_sigma': ...}` when `sample_mixed` is True.

        Raises:
            ValueError: If `sample_mixed` is True and `coordinates` is None.
        """
        c_wp = (c if c_swapped is None else c_swapped).clone()
        wp = self.mapping(z,
                          c_wp,
                          truncation_psi=truncation_psi,
                          truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)
        if style_mixing_prob > 0:
            # Pick a cutoff layer in [1, num_ws). Keep it with probability
            # `style_mixing_prob`; otherwise use num_ws, which means no
            # mixing. Layers from the cutoff onward come from a fresh random
            # z. NOTE: this second mapping uses the unswapped `c` and no
            # truncation.
            num_ws = wp.shape[1]
            cutoff = torch.empty([], dtype=torch.int64,
                                 device=wp.device).random_(1, num_ws)
            cutoff = torch.where(
                torch.rand([], device=wp.device) < style_mixing_prob,
                cutoff, torch.full_like(cutoff, num_ws))
            wp[:, cutoff:] = self.mapping(torch.randn_like(z),
                                          c,
                                          update_emas=update_emas)[:, cutoff:]

        if sample_mixed:
            # Only for density regularization in the training process.
            # Explicit check instead of `assert`, which `python -O` strips.
            if coordinates is None:
                raise ValueError(
                    '`coordinates` must be provided when `sample_mixed=True`.')
            sample_sigma = self.sample_mixed(coordinates,
                                             torch.randn_like(coordinates),
                                             wp,
                                             update_emas=False)['sigma']
            return {
                'wp': wp,
                'sample_sigma': sample_sigma
            }

        gen_output = self.synthesis(
            wp,
            c,
            update_emas=update_emas,
            neural_rendering_resolution=neural_rendering_resolution,
            **synthesis_kwargs)
        return {
            'wp': wp,
            'gen_output': gen_output,
        }
class OSGDecoder(nn.Module):
    """Fully-connected head used by EG3D to decode tri-plane features.

    Per-point features are averaged over the plane axis and passed through a
    two-layer MLP (Softplus in between). The first output channel is
    returned as `sigma` and the remaining `decoder_output_dim` channels,
    squashed into (-0.001, 1.001), as `rgb`.
    """

    def __init__(self, n_features, options):
        """Build the MLP.

        Args:
            n_features: Channel count of the incoming per-plane features.
            options: Dict with `'decoder_lr_mul'` (learning-rate multiplier
                for both layers) and `'decoder_output_dim'`.
        """
        super().__init__()
        self.hidden_dim = 64
        lr_mul = options['decoder_lr_mul']
        to_hidden = FullyConnectedLayer(n_features,
                                        self.hidden_dim,
                                        lr_multiplier=lr_mul)
        to_output = FullyConnectedLayer(self.hidden_dim,
                                        1 + options['decoder_output_dim'],
                                        lr_multiplier=lr_mul)
        # Layer order (net.0, net.1, net.2) matters for checkpoint keys.
        self.net = nn.Sequential(to_hidden, nn.Softplus(), to_output)

    def forward(self, point_features, wp=None, dirs=None):
        """Decode features into `{'rgb': ..., 'sigma': ...}`.

        Args:
            point_features: Tensor of shape `[N, 3, M, C]`. Dim 1 indexes
                the planes and is averaged away.
            wp: Accepted for interface compatibility; unused.
            dirs: Accepted for interface compatibility; unused.
        """
        averaged = point_features.mean(1)
        batch, num_points, channels = averaged.shape
        decoded = self.net(averaged.view(batch * num_points, channels))
        decoded = decoded.view(batch, num_points, -1)
        sigma = decoded[..., 0:1]
        # Sigmoid clamping from MipNeRF: widen the range slightly past
        # [0, 1] so targets at the boundaries stay reachable.
        rgb = torch.sigmoid(decoded[..., 1:]) * (1 + 2 * 0.001) - 0.001
        return {'rgb': rgb, 'sigma': sigma}