# Page header captured together with this file (not part of the module):
#   BerfScene / models/eg3d_generator_fv.py
#   uploader avatar: 3v324v23 -- commit "init" (2f85de4)
# python3.8
"""Contains the implementation of generator described in EG3D."""
import torch
import torch.nn as nn
import numpy as np
from models.utils.official_stylegan2_model_helper import MappingNetwork
from models.utils.official_stylegan2_model_helper import FullyConnectedLayer
from models.utils.eg3d_superres import SuperresolutionHybrid2X
from models.utils.eg3d_superres import SuperresolutionHybrid4X
from models.utils.eg3d_superres import SuperresolutionHybrid8XDC
from models.rendering.renderer import Renderer
from models.rendering.feature_extractor import FeatureExtractor
from models.volumegan_generator import FeatureVolume
from models.volumegan_generator import PositionEncoder
class EG3DGeneratorFV(nn.Module):
    """EG3D-style generator that renders from a generated feature volume.

    As wired in this class: a mapping network turns `(z, c)` into latent
    codes `wp`; `FeatureVolume` turns `wp` into a reference representation;
    `Renderer` composites per-ray features and depth from it; finally a
    super-resolution network produces the output image.
    """

    def __init__(
            self,
            # Input latent (Z) dimensionality.
            z_dim,
            # Conditioning label (C) dimensionality.
            c_dim,
            # Intermediate latent (W) dimensionality.
            w_dim,
            # Final output image resolution. Must be 128, 256 or 512.
            img_resolution,
            # Number of output color channels.
            img_channels,
            # Number of fp16 layers of SR Network.
            sr_num_fp16_res=0,
            # Arguments for MappingNetwork. `None` means no extra arguments.
            mapping_kwargs=None,
            # Arguments for rendering. `None` means an empty dict; note that
            # the key 'sr_antialias' is read unconditionally below.
            rendering_kwargs=None,
            # Arguments for SuperResolution Network. `None` means no
            # overrides.
            sr_kwargs=None,
            # Configs for FeatureVolume. `None` selects
            # dict(feat_res=32, init_res=4, base_channels=256,
            #      output_channels=32, w_dim=512).
            fv_cfg=None,
            # Configs for position encoder. `None` selects
            # dict(input_dim=3, max_freq_log2=9, N_freqs=10).
            embed_cfg=None,
    ):
        """Builds all sub-networks of the generator.

        Raises:
            KeyError: If `rendering_kwargs` lacks the key 'sr_antialias'.
            TypeError: If `img_resolution` is not one of 128, 256, 512.
        """
        super().__init__()

        # Default arguments are evaluated once at definition time, so dict
        # defaults would be shared by every instance. Resolve `None` to a
        # fresh container per call instead.
        if mapping_kwargs is None:
            mapping_kwargs = {}
        if rendering_kwargs is None:
            rendering_kwargs = {}
        if sr_kwargs is None:
            sr_kwargs = {}
        if fv_cfg is None:
            fv_cfg = dict(feat_res=32,
                          init_res=4,
                          base_channels=256,
                          output_channels=32,
                          w_dim=512)
        if embed_cfg is None:
            embed_cfg = dict(input_dim=3, max_freq_log2=10 - 1, N_freqs=10)

        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        # Set up mapping network.
        # Here `num_ws = 2`: one for FeatureVolume Network injection and one
        # for post_neural_renderer injection.
        num_ws = 2
        self.mapping_network = MappingNetwork(z_dim=z_dim,
                                              c_dim=c_dim,
                                              w_dim=w_dim,
                                              num_ws=num_ws,
                                              **mapping_kwargs)

        # Set up the overall renderer.
        self.renderer = Renderer()

        # Set up the feature extractor.
        self.feature_extractor = FeatureExtractor(ref_mode='feature_volume')

        # Set up the reference representation generator.
        self.ref_representation_generator = FeatureVolume(**fv_cfg)

        # Set up the position encoder.
        self.position_encoder = PositionEncoder(**embed_cfg)

        # This generator uses no post module in the feature extractor.
        self.post_module = None

        # Set up the post neural renderer (a super-resolution network).
        # Entries of `sr_kwargs` override the defaults assembled here.
        sr_kwargs_total = dict(
            channels=32,
            img_resolution=img_resolution,
            sr_num_fp16_res=sr_num_fp16_res,
            sr_antialias=rendering_kwargs['sr_antialias'],
        )
        sr_kwargs_total.update(**sr_kwargs)
        if img_resolution == 128:
            self.post_neural_renderer = SuperresolutionHybrid2X(
                **sr_kwargs_total)
        elif img_resolution == 256:
            self.post_neural_renderer = SuperresolutionHybrid4X(
                **sr_kwargs_total)
        elif img_resolution == 512:
            self.post_neural_renderer = SuperresolutionHybrid8XDC(
                **sr_kwargs_total)
        else:
            # Exception type kept as-is for callers that already catch it.
            raise TypeError(f'Unsupported image resolution: {img_resolution}!')

        # Set up the fully-connected layer head.
        self.fc_head = OSGDecoder(
            32, {
                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': 32
            })

        # Set up some rendering related arguments.
        self.neural_rendering_resolution = rendering_kwargs.get(
            'resolution', 64)
        # The caller's dict is stored by reference (not copied).
        self.rendering_kwargs = rendering_kwargs

    def mapping(self,
                z,
                c,
                truncation_psi=1,
                truncation_cutoff=None,
                update_emas=False):
        """Runs the mapping network with the conditioning pre-processed.

        `c` is replaced by zeros when
        `rendering_kwargs['c_gen_conditioning_zero']` is truthy, and is then
        multiplied by `rendering_kwargs.get('c_scale', 0)`; a missing
        'c_scale' therefore zeroes the conditioning as well.

        Returns:
            Whatever `self.mapping_network` returns for the processed inputs.
        """
        if self.rendering_kwargs['c_gen_conditioning_zero']:
            c = torch.zeros_like(c)
        scaled_c = c * self.rendering_kwargs.get('c_scale', 0)
        return self.mapping_network(z,
                                    scaled_c,
                                    truncation_psi=truncation_psi,
                                    truncation_cutoff=truncation_cutoff,
                                    update_emas=update_emas)

    def synthesis(self,
                  wp,
                  c,
                  neural_rendering_resolution=None,
                  update_emas=False,
                  **synthesis_kwargs):
        """Renders images from latent codes `wp` and camera labels `c`.

        Args:
            wp: Latent codes; `wp.shape[0]` is used as the batch size.
            c: Labels whose first 16 entries per sample are viewed as a 4x4
                camera-to-world matrix.
            neural_rendering_resolution: If given, overrides AND persists
                `self.neural_rendering_resolution`.
            update_emas: Accepted for interface compatibility; not used here.
            **synthesis_kwargs: Forwarded to the super-resolution network,
                except 'noise_mode', which is always taken from
                `rendering_kwargs['superresolution_noise_mode']`.

        Returns:
            A dict with 'image' (super-resolved output), 'image_raw' (first
            three channels of the rendered feature image) and 'image_depth'.
        """
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        if self.rendering_kwargs.get('random_pose', False):
            # No explicit pose is handed to the renderer in this mode.
            cam2world_matrix = None

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        feature_volume = self.ref_representation_generator(wp)
        rendering_result = self.renderer(
            wp=wp,
            feature_extractor=self.feature_extractor,
            rendering_options=self.rendering_kwargs,
            cam2world_matrix=cam2world_matrix,
            position_encoder=self.position_encoder,
            ref_representation=feature_volume,
            post_module=self.post_module,
            fc_head=self.fc_head)

        feature_samples = rendering_result['composite_rgb']
        depth_samples = rendering_result['composite_depth']

        # Reshape to keep consistent with 'raw' neural-rendered image.
        # Assumes samples are laid out as [N, H * W, C] -- TODO confirm
        # against the renderer.
        N = wp.shape[0]
        H = W = self.neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        # Run the post neural renderer to get final image.
        # Here, the post neural renderer is a super-resolution network.
        rgb_image = feature_image[:, :3]
        sr_extra_kwargs = {
            key: value
            for key, value in synthesis_kwargs.items() if key != 'noise_mode'
        }
        sr_image = self.post_neural_renderer(
            rgb_image,
            feature_image,
            wp,
            noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
            **sr_extra_kwargs)

        return {
            'image': sr_image,
            'image_raw': rgb_image,
            'image_depth': depth_image
        }

    def _query_points(self, wp, coordinates, directions):
        """Evaluates the renderer's `get_sigma_rgb` at the given points."""
        feature_volume = self.ref_representation_generator(wp)
        return self.renderer.get_sigma_rgb(
            wp=wp,
            points=coordinates,
            feature_extractor=self.feature_extractor,
            fc_head=self.fc_head,
            rendering_options=self.rendering_kwargs,
            ref_representation=feature_volume,
            position_encoder=self.position_encoder,
            post_module=self.post_module,
            ray_dirs=directions)

    def sample(self,
               coordinates,
               directions,
               z,
               c,
               truncation_psi=1,
               truncation_cutoff=None,
               update_emas=False):
        """Computes RGB features and density for arbitrary 3D coordinates.

        Mostly used for extracting shapes. `z` and `c` go straight to
        `self.mapping_network`, i.e. without the conditioning pre-processing
        done in `self.mapping()`.
        """
        wp = self.mapping_network(z,
                                  c,
                                  truncation_psi=truncation_psi,
                                  truncation_cutoff=truncation_cutoff,
                                  update_emas=update_emas)
        return self._query_points(wp, coordinates, directions)

    def sample_mixed(self,
                     coordinates,
                     directions,
                     wp):
        """Same as `self.sample()`, but expects latent codes `wp`.

        Use this when the latent codes are already available instead of
        Gaussian noise `z`.
        """
        return self._query_points(wp, coordinates, directions)

    def forward(self,
                z,
                c,
                c_swapped=None,  # `c_swapped` is swapped pose conditioning.
                style_mixing_prob=0,
                truncation_psi=1,
                truncation_cutoff=None,
                neural_rendering_resolution=None,
                update_emas=False,
                sample_mixed=False,
                coordinates=None,
                **synthesis_kwargs):
        """Maps `(z, c)` to latent codes, then renders or samples density.

        Args:
            z: Input latent codes.
            c: Labels; always used for rendering, and used for mapping
                unless `c_swapped` is given.
            c_swapped: Optional labels used for the mapping network instead
                of `c`.
            style_mixing_prob: Probability of replacing the tail of `wp`
                with codes mapped from freshly drawn noise.
            truncation_psi: Forwarded to the mapping network.
            truncation_cutoff: Forwarded to the mapping network.
            neural_rendering_resolution: Forwarded to `self.synthesis()`.
            update_emas: Forwarded to the mapping network and synthesis.
            sample_mixed: If True, skip rendering and return densities at
                `coordinates` instead.
            coordinates: Query points; required when `sample_mixed` is True.
            **synthesis_kwargs: Forwarded to `self.synthesis()`.

        Returns:
            `{'wp', 'gen_output'}` when rendering, or
            `{'wp', 'sample_sigma'}` when `sample_mixed` is True.
        """
        # Render a batch of generated images.
        c_wp = c.clone() if c_swapped is None else c_swapped.clone()

        # NOTE(review): `self.mapping_network` is called directly, so the
        # 'c_gen_conditioning_zero' / 'c_scale' handling of `self.mapping()`
        # is not applied here -- confirm this is intended.
        wp = self.mapping_network(z,
                                  c_wp,
                                  truncation_psi=truncation_psi,
                                  truncation_cutoff=truncation_cutoff,
                                  update_emas=update_emas)

        if style_mixing_prob > 0:
            # Draw a cutoff in [1, num_ws - 1]; with probability
            # `1 - style_mixing_prob` it becomes `num_ws`, which makes the
            # slices below empty (i.e. no mixing).
            cutoff = torch.empty([], dtype=torch.int64,
                                 device=wp.device).random_(1, wp.shape[1])
            cutoff = torch.where(
                torch.rand([], device=wp.device) < style_mixing_prob, cutoff,
                torch.full_like(cutoff, wp.shape[1]))
            wp[:, cutoff:] = self.mapping_network(
                torch.randn_like(z), c, update_emas=update_emas)[:, cutoff:]

        if sample_mixed:
            # Only for density regularization in training process.
            assert coordinates is not None, (
                '`coordinates` is required when `sample_mixed` is True.')
            # Ray directions are drawn at random for this query.
            sample_sigma = self.sample_mixed(coordinates,
                                             torch.randn_like(coordinates),
                                             wp)['sigma']
            return {
                'wp': wp,
                'sample_sigma': sample_sigma
            }

        gen_output = self.synthesis(
            wp,
            c,
            update_emas=update_emas,
            neural_rendering_resolution=neural_rendering_resolution,
            **synthesis_kwargs)
        return {
            'wp': wp,
            'gen_output': gen_output,
        }
class OSGDecoder(nn.Module):
    """Defines fully-connected layer head in EG3D.

    A two-layer MLP (FullyConnectedLayer -> Softplus -> FullyConnectedLayer)
    that maps each point feature to one density value plus
    `options['decoder_output_dim']` color/feature values.
    """

    # Padding of the sigmoid output range, i.e. colors lie in
    # (-_RGB_PADDING, 1 + _RGB_PADDING).
    _RGB_PADDING = 0.001

    def __init__(self, n_features, options, hidden_dim=64):
        """Builds the MLP.

        Args:
            n_features: Number of channels of each input point feature.
            options: Dict providing 'decoder_lr_mul' (learning-rate
                multiplier of both layers) and 'decoder_output_dim' (number
                of color/feature channels to predict).
            hidden_dim: Width of the hidden layer. Defaults to 64, the value
                that was previously hard-coded.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        lr_multiplier = options['decoder_lr_mul']
        # One extra output channel holds the density (sigma).
        output_dim = 1 + options['decoder_output_dim']
        self.net = nn.Sequential(
            FullyConnectedLayer(n_features,
                                self.hidden_dim,
                                lr_multiplier=lr_multiplier),
            nn.Softplus(),
            FullyConnectedLayer(self.hidden_dim,
                                output_dim,
                                lr_multiplier=lr_multiplier))

    def forward(self, point_features, wp=None, dirs=None):
        """Decodes point features into color and density.

        Args:
            point_features: Tensor with shape [N, C, M, 1].
            wp: Unused; accepted for interface compatibility.
            dirs: Unused; accepted for interface compatibility.

        Returns:
            A dict with 'rgb' of shape [N, M, decoder_output_dim] and 'sigma'
            of shape [N, M, 1].
        """
        # [N, C, M, 1] -> [N, C, M] -> [N, M, C].
        features = point_features.squeeze(-1).permute(0, 2, 1)
        N, M, C = features.shape

        # Run the MLP on all points at once, then restore the batch axis.
        decoded = self.net(features.reshape(N * M, C)).reshape(N, M, -1)

        # Uses sigmoid clamping from MipNeRF
        padding = self._RGB_PADDING
        rgb = torch.sigmoid(decoded[..., 1:]) * (1 + 2 * padding) - padding
        sigma = decoded[..., 0:1]
        return {'rgb': rgb, 'sigma': sigma}