"""Risk-biased CVAE models (inference and training variants) and their factory."""

from warnings import warn
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn as nn
from einops import rearrange, repeat

from risk_biased.models.map_encoder import MapEncoderNN
from risk_biased.models.mlp import MLP
from risk_biased.models.cvae_params import CVAEParams
from risk_biased.models.cvae_encoders import (
    AbstractLatentDistribution,
    CVAEEncoder,
    BiasedEncoderNN,
    FutureEncoderNN,
    InferenceEncoderNN,
)
from risk_biased.models.cvae_decoder import (
    CVAEAccelerationDecoder,
    CVAEParametrizedDecoder,
    DecoderNN,
)
from risk_biased.utils.cost import BaseCostTorch, get_cost
from risk_biased.utils.loss import (
    reconstruction_loss,
    risk_loss_function,
)

# NOTE: AbstractLatentDistribution is imported a second time here; this later
# binding is the one in effect for the rest of the module.
from risk_biased.models.latent_distributions import (
    GaussianLatentDistribution,
    QuantizedDistributionCreator,
    AbstractLatentDistribution,
)
from risk_biased.utils.metrics import FDE, minFDE
from risk_biased.utils.risk import AbstractMonteCarloRiskEstimator


class InferenceBiasedCVAE(nn.Module):
    """CVAE with a biased encoder module for risk-biased trajectory forecasting.

    Args:
        absolute_encoder: encoder model for the absolute positions of the agents
        map_encoder: encoder model for map objects
        biased_encoder: biased encoder that uses past and auxiliary input,
        inference_encoder: inference encoder that uses only past,
        decoder: CVAE decoder model
        prior_distribution: prior distribution for the latent space.
    """

    def __init__(
        self,
        absolute_encoder: MLP,
        map_encoder: MapEncoderNN,
        biased_encoder: CVAEEncoder,
        inference_encoder: CVAEEncoder,
        decoder: CVAEAccelerationDecoder,
        prior_distribution: AbstractLatentDistribution,
    ) -> None:
        super().__init__()
        self.biased_encoder = biased_encoder
        self.inference_encoder = inference_encoder
        self.decoder = decoder
        self.map_encoder = map_encoder
        self.absolute_encoder = absolute_encoder
        self.prior_distribution = prior_distribution

    def cvae_parameters(self, recurse: bool = True):
        """Define an iterator over all the parameters related to the cvae."""
        yield from self.absolute_encoder.parameters(recurse=recurse)
        yield from self.map_encoder.parameters(recurse=recurse)
        yield from self.inference_encoder.parameters(recurse=recurse)
        yield from self.decoder.parameters(recurse=recurse)

    def biased_parameters(self, recurse: bool = True):
        """Define an iterator over only the parameters related to the biaser."""
        yield from self.biased_encoder.biased_parameters(recurse=recurse)

    def forward(
        self,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        map: torch.Tensor,
        mask_map: torch.Tensor,
        offset: torch.Tensor,
        *,
        x_ego: Optional[torch.Tensor] = None,
        y_ego: Optional[torch.Tensor] = None,
        risk_level: Optional[torch.Tensor] = None,
        n_samples: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, AbstractLatentDistribution]:
        """Forward function that outputs a noisy reconstruction of y, the sample
        weights and the latent posterior distribution.

        Args:
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            map: (batch_size, num_objects, object_sequence_length, map_feature_dim)
                tensor of encoded map objects
            mask_map: (batch_size, num_objects, object_sequence_length) tensor of
                bool mask
            offset: (batch_size, num_agents, state_dim) offset position from ego.
            x_ego: (batch_size, 1, num_steps, state_dim) ego history
            y_ego: (batch_size, 1, num_steps_future, state_dim) ego future
            risk_level (optional): (batch_size, num_agents) tensor of risk levels
                desired for future trajectories. Defaults to None.
            n_samples (optional): number of samples to predict, (if 0 one sample
                with no extra dimension). Defaults to 0.

        Returns:
            noisy reconstruction y of size
            (batch_size, num_agents, num_steps_future, state_dim), as well as
            weights of the samples and the latent distribution.
            The biased encoder is only used when risk_level is given.
        """
        encoded_map = self.map_encoder(map, mask_map)
        # Collapse the per-point map mask into a per-object mask.
        mask_map = mask_map.any(-1)
        encoded_absolute = self.absolute_encoder(offset)

        if risk_level is not None:
            biased_latent_distribution = self.biased_encoder(
                x,
                mask_x,
                encoded_absolute,
                encoded_map,
                mask_map,
                x_ego=x_ego,
                y_ego=y_ego,
                offset=offset,
                risk_level=risk_level,
            )
            inference_latent_distribution = self.inference_encoder(
                x,
                mask_x,
                encoded_absolute,
                encoded_map,
                mask_map,
            )
            # Combine the unbiased and biased distributions, weighted by risk.
            latent_distribution = inference_latent_distribution.average(
                biased_latent_distribution, risk_level.unsqueeze(-1)
            )
        else:
            latent_distribution = self.inference_encoder(
                x,
                mask_x,
                encoded_absolute,
                encoded_map,
                mask_map,
            )

        z_sample, weights = latent_distribution.sample(n_samples=n_samples)
        # An agent has a valid latent if any of its history steps is valid.
        mask_z = mask_x.any(-1)
        y_sample = self.decoder(
            z_sample,
            mask_z,
            x,
            mask_x,
            encoded_absolute,
            encoded_map,
            mask_map,
            offset,
        )
        return y_sample, weights, latent_distribution

    def decode(
        self,
        z_samples: torch.Tensor,
        mask_z: torch.Tensor,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        map: torch.Tensor,
        mask_map: torch.Tensor,
        offset: torch.Tensor,
    ):
        """Returns predicted y values conditioned on z_samples and the other
        observations.

        Args:
            z_samples: (batch_size, num_agents, (n_samples), latent_dim) tensor of
                latent samples
            mask_z: (batch_size, num_agents) bool mask
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            map: (batch_size, num_objects, object_sequence_length, map_feature_dim)
                tensor of encoded map objects
            mask_map: (batch_size, num_objects, object_sequence_length) tensor True
                where map features are good False where it is padding
            offset: (batch_size, num_agents, state_dim) offset position from ego.
        """
        encoded_map = self.map_encoder(map, mask_map)
        mask_map = mask_map.any(-1)
        encoded_absolute = self.absolute_encoder(offset)

        return self.decoder(
            z_samples=z_samples,
            mask_z=mask_z,
            x=x,
            mask_x=mask_x,
            encoded_absolute=encoded_absolute,
            encoded_map=encoded_map,
            mask_map=mask_map,
            offset=offset,
        )


class TrainingBiasedCVAE(InferenceBiasedCVAE):
    """CVAE with a biased encoder module for risk-biased trajectory forecasting.
    This module is a non-sampling-based version of BiasedLatentCVAE.

    Args:
        absolute_encoder: encoder model for the absolute positions of the agents
        map_encoder: encoder model for map objects
        biased_encoder: biased encoder that uses past and auxiliary input,
        inference_encoder: inference encoder that uses only past,
        decoder: CVAE decoder model
        future_encoder: training encoder that uses past and future,
        cost_function: cost function used to compute the risk objective
        risk_estimator: risk estimator used to compute the risk objective
        prior_distribution: prior distribution for the latent space.
        training_mode (optional): set to "cvae" to train the unbiased model, set to
            "bias" to train the biased encoder. Defaults to "cvae".
        latent_regularization (optional): regularization term for the latent space.
            Defaults to 0.
        risk_assymetry_factor (optional): risk asymmetry factor used to compute the
            risk objective avoiding underestimations. Defaults to 100.
    """

    def __init__(
        self,
        absolute_encoder: MLP,
        map_encoder: MapEncoderNN,
        biased_encoder: CVAEEncoder,
        inference_encoder: CVAEEncoder,
        decoder: CVAEAccelerationDecoder,
        future_encoder: CVAEEncoder,
        cost_function: BaseCostTorch,
        risk_estimator: AbstractMonteCarloRiskEstimator,
        prior_distribution: AbstractLatentDistribution,
        training_mode: str = "cvae",
        latent_regularization: float = 0.0,
        risk_assymetry_factor: float = 100.0,
    ) -> None:
        super().__init__(
            absolute_encoder,
            map_encoder,
            biased_encoder,
            inference_encoder,
            decoder,
            prior_distribution,
        )
        self.future_encoder = future_encoder
        self._cost = cost_function
        self._risk = risk_estimator
        self.set_training_mode(training_mode)
        self.regularization_factor = latent_regularization
        self.risk_assymetry_factor = risk_assymetry_factor

    def cvae_parameters(self, recurse: bool = True):
        """Iterate over the cvae parameters, including the future encoder."""
        yield from super().cvae_parameters(recurse)
        yield from self.future_encoder.parameters(recurse)

    def get_parameters(self, recurse: bool = True):
        """Returns a list of two parameter iterators: cvae and encoder only."""
        return [
            self.cvae_parameters(recurse),
            self.biased_parameters(recurse),
        ]

    def set_training_mode(self, training_mode: str) -> None:
        """
        Change the training mode (get_loss function will be different depending on
        the mode).

        Warning: This does not freeze the decoder because the gradient must pass
        through it. The decoder should be frozen at the optimizer level when
        changing mode.
        """
        assert training_mode in ("cvae", "bias"), (
            f"Unknown training mode: {training_mode!r}"
        )
        self.training_mode = training_mode
        if training_mode == "cvae":
            self.get_loss = self.get_loss_cvae
        else:
            self.get_loss = self.get_loss_biased

    def forward_future(
        self,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        map: torch.Tensor,
        mask_map: torch.Tensor,
        y: torch.Tensor,
        mask_y: torch.Tensor,
        offset: torch.Tensor,
        return_inference: bool = False,
    ) -> Union[
        Tuple[torch.Tensor, AbstractLatentDistribution],
        Tuple[torch.Tensor, AbstractLatentDistribution, AbstractLatentDistribution],
    ]:
        """Forward function that outputs a noisy reconstruction of y and the latent
        posterior distribution computed from both past and future.

        Args:
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            map: (batch_size, num_objects, object_sequence_length, map_feature_dim)
                tensor of encoded map objects
            mask_map: (batch_size, num_objects, object_sequence_length) tensor of
                bool mask
            y: (batch_size, num_agents, num_steps_future, state_dim) tensor of
                future trajectory.
            mask_y: (batch_size, num_agents, num_steps_future) tensor of bool mask.
            offset: (batch_size, num_agents, state_dim) offset position from ego.
            return_inference: (optional) Set to true if the inference distribution
                should be returned as well. Defaults to False.

        Returns:
            noisy reconstruction y of size
            (batch_size, num_agents, num_steps_future, state_dim), and the
            distribution of the latent posterior, as well as, optionally, the
            distribution of the latent inference posterior.
        """
        encoded_map = self.map_encoder(map, mask_map)
        mask_map = mask_map.any(-1)
        encoded_absolute = self.absolute_encoder(offset)
        latent_distribution = self.future_encoder(
            x,
            mask_x,
            y=y,
            mask_y=mask_y,
            encoded_absolute=encoded_absolute,
            encoded_map=encoded_map,
            mask_map=mask_map,
        )
        # The sample weights are not needed for the reconstruction.
        z_sample, _ = latent_distribution.sample()
        mask_z = mask_x.any(-1)

        y_sample = self.decoder(
            z_sample,
            mask_z,
            x,
            mask_x,
            encoded_absolute,
            encoded_map,
            mask_map,
            offset,
        )

        if not return_inference:
            return y_sample, latent_distribution

        inference_distribution = self.inference_encoder(
            x,
            mask_x,
            encoded_absolute,
            encoded_map,
            mask_map,
        )
        return (
            y_sample,
            latent_distribution,
            inference_distribution,
        )

    def get_loss_cvae(
        self,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        map: torch.Tensor,
        mask_map: torch.Tensor,
        y: torch.Tensor,
        *,
        mask_y: torch.Tensor,
        mask_loss: torch.Tensor,
        offset: torch.Tensor,
        unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        kl_weight: float,
        kl_threshold: float,
        **kwargs,
    ) -> Tuple[torch.Tensor, dict]:
        """Compute and return the CVAE loss averaged over batch and sequence time
        steps, along with desired loss-related metrics for logging.

        Args:
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            map: (batch_size, num_objects, object_sequence_length, map_feature_dim)
                tensor of encoded map objects
            mask_map: (batch_size, num_objects, object_sequence_length) tensor True
                where map features are good False where it is padding
            y: (batch_size, num_agents, num_steps_future, state_dim) tensor of
                future trajectory.
            mask_y: (batch_size, num_agents, num_steps_future) tensor of bool mask.
            mask_loss: (batch_size, num_agents, num_steps_future) tensor of bool
                mask set to True where the loss should be computed and to False
                where it shouldn't
            offset: (batch_size, num_agents, state_dim) offset position from ego.
            unnormalizer: function that takes in a trajectory and an offset and
                that outputs the unnormalized trajectory
            kl_weight: weight to apply to the KL loss (normal value is 1.0, larger
                values can be used for disentanglement)
            kl_threshold: minimum float value threshold applied to the KL loss

        Returns:
            torch.Tensor: (1,) loss tensor
            dict: dict that contains loss-related metrics to be logged
        """
        log_dict = dict()

        if not mask_loss.any():
            warn("A batch is dropped because the whole loss is masked.")
            return torch.zeros(1, requires_grad=True), {}

        mask_z = mask_x.any(-1)
        (y_sample, latent_distribution, inference_distribution) = self.forward_future(
            x,
            mask_x,
            map,
            mask_map,
            y,
            mask_y,
            offset,
            return_inference=True,
        )

        log_dict["fde/encoded"] = FDE(
            unnormalizer(y_sample, offset), unnormalizer(y, offset), mask_loss
        ).item()
        rec_loss = reconstruction_loss(y_sample, y, mask_loss)

        # KL between the future-aware posterior and the past-only inference.
        kl_loss = latent_distribution.kl_loss(
            inference_distribution,
            kl_threshold,
            mask_z,
        )
        # KL between the posterior and the fixed prior (latent regularization).
        kl_loss_prior = latent_distribution.kl_loss(
            self.prior_distribution,
            kl_threshold,
            mask_z,
        )

        sampling_loss = latent_distribution.sampling_loss()

        log_dict["loss/rec"] = rec_loss.item()
        log_dict["loss/kl"] = kl_loss.item()
        log_dict["loss/kl_prior"] = kl_loss_prior.item()
        log_dict["loss/sampling"] = sampling_loss.item()
        log_dict.update(latent_distribution.log_dict("future"))
        log_dict.update(inference_distribution.log_dict("inference"))

        loss = (
            rec_loss
            + kl_weight * kl_loss
            + self.regularization_factor * kl_loss_prior
            + sampling_loss
        )

        log_dict["loss/total"] = loss.item()

        return loss, log_dict

    def get_loss_biased(
        self,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        map: torch.Tensor,
        mask_map: torch.Tensor,
        y: torch.Tensor,
        *,
        mask_loss: torch.Tensor,
        offset: torch.Tensor,
        unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        risk_level: torch.Tensor,
        x_ego: torch.Tensor,
        y_ego: torch.Tensor,
        kl_weight: float,
        kl_threshold: float,
        risk_weight: float,
        n_samples_risk: int,
        n_samples_biased: int,
        dt: float,
        **kwargs,
    ) -> Tuple[torch.Tensor, dict]:
        """Compute and return risk-biased CVAE loss averaged over batch and
        sequence time steps, along with desired loss-related metrics for logging.

        Args:
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            map: (batch_size, num_objects, object_sequence_length, map_feature_dim)
                tensor of encoded map objects
            mask_map: (batch_size, num_objects, object_sequence_length) tensor True
                where map features are good False where it is padding
            y: (batch_size, num_agents, num_steps_future, state_dim) tensor of
                future trajectory.
            mask_loss: (batch_size, num_agents, num_steps_future) tensor of bool
                mask set to True where the loss should be computed and to False
                where it shouldn't
            offset: (batch_size, num_agents, state_dim) offset position from ego.
            unnormalizer: function that takes in a trajectory and an offset and
                that outputs the unnormalized trajectory
            risk_level: (batch_size, num_agents) tensor of risk levels desired for
                future trajectories
            x_ego: (batch_size, 1, num_steps, state_dim) tensor of ego history
            y_ego: (batch_size, 1, num_steps_future, state_dim) tensor of ego
                future trajectory
            kl_weight: weight to apply to the KL loss (normal value is 1.0, larger
                values can be used for disentanglement)
            kl_threshold: minimum float value threshold applied to the KL loss
            risk_weight: weight to apply to the risk loss (beta parameter in our
                document)
            n_samples_risk: number of sample to use for Monte-Carlo estimation of
                the risk using the unbiased distribution (values below 1 are
                raised to 1)
            n_samples_biased: number of sample to use for Monte-Carlo estimation of
                the risk using the biased distribution (values below 1 are raised
                to 1)
            dt: time step in trajectories

        Returns:
            torch.Tensor: (1,) loss tensor
            dict: dict that contains loss-related metrics to be logged
        """
        log_dict = dict()

        if not mask_loss.any():
            warn("A batch is dropped because the whole loss is masked.")
            return torch.zeros(1, requires_grad=True), {}

        mask_z = mask_x.any(-1)
        # Per-agent mask: True where at least one future step contributes to the loss.
        mask_agent = mask_loss.any(-1).float()

        # Computing unbiased samples
        n_samples_risk = max(1, n_samples_risk)
        n_samples_biased = max(1, n_samples_biased)

        cost = []
        weights = []
        # Unbiased samples are drawn in packs to bound the memory usage.
        pack_size = min(n_samples_risk, n_samples_biased)
        with torch.no_grad():
            encoded_map = self.map_encoder(map, mask_map)
            mask_map = mask_map.any(-1)
            encoded_absolute = self.absolute_encoder(offset)

            inference_distribution = self.inference_encoder(
                x,
                mask_x,
                encoded_absolute,
                encoded_map,
                mask_map,
            )
            # Loop-invariant: the loss mask repeated along the sample dimension.
            mask_loss_samples = repeat(mask_loss, "b a t -> b a s t", s=pack_size)
            for _ in range(n_samples_risk // pack_size):
                z_samples, w = inference_distribution.sample(
                    n_samples=pack_size,
                )

                y_samples = self.decoder(
                    z_samples=z_samples,
                    mask_z=mask_z,
                    x=x,
                    mask_x=mask_x,
                    encoded_absolute=encoded_absolute,
                    encoded_map=encoded_map,
                    mask_map=mask_map,
                    offset=offset,
                )

                # Computing unbiased cost
                cost.append(
                    get_cost(
                        self._cost,
                        x,
                        y_samples,
                        offset,
                        x_ego,
                        y_ego,
                        dt,
                        unnormalizer,
                        mask_loss_samples,
                    )
                )
                weights.append(w)

            # Concatenate the packs along the sample dimension.
            cost = torch.cat(cost, 2)
            weights = torch.cat(weights, 2)
            risk_cost = self._risk(risk_level, cost, weights)
            # Logged on the last pack of unbiased samples only.
            log_dict["fde/prior"] = FDE(
                unnormalizer(y_samples, offset),
                unnormalizer(y, offset).unsqueeze(-3),
                mask_loss_samples,
            ).item()
            # Use the number of samples actually drawn: it is smaller than
            # n_samples_risk when n_samples_risk is not a multiple of pack_size.
            mask_cost_samples = repeat(mask_z, "b a -> b a s", s=cost.shape[2])
            mean_cost = (cost * mask_cost_samples.float() * weights).sum(2) / (
                (mask_cost_samples.float() * weights).sum(2).clamp_min(1)
            )
            log_dict["cost/mean"] = (
                (mean_cost * mask_agent).sum() / (mask_agent.sum())
            ).item()

        # Computing biased latent parameters
        biased_distribution = self.biased_encoder(
            x,
            mask_x,
            encoded_absolute.detach(),
            encoded_map.detach(),
            mask_map,
            risk_level=risk_level,
            x_ego=x_ego,
            y_ego=y_ego,
            offset=offset,
        )
        biased_distribution = inference_distribution.average(
            biased_distribution, risk_level.unsqueeze(-1)
        )

        # Computing biased samples
        z_biased_samples, weights = biased_distribution.sample(
            n_samples=n_samples_biased
        )
        mask_z_samples = repeat(mask_z, "b a -> b a s ()", s=n_samples_biased)
        log_dict["latent/abs_samples_biased"] = (
            (z_biased_samples.abs() * mask_z_samples.float()).sum()
            / (mask_z_samples.float().sum())
        ).item()

        y_biased_samples = self.decoder(
            z_samples=z_biased_samples,
            mask_z=mask_z,
            x=x,
            mask_x=mask_x,
            encoded_absolute=encoded_absolute,
            encoded_map=encoded_map,
            mask_map=mask_map,
            offset=offset,
        )

        # The biased samples have n_samples_biased samples, which differs from
        # pack_size when n_samples_biased > n_samples_risk, so they need a
        # loss mask of their own size.
        mask_loss_biased_samples = repeat(
            mask_loss, "b a t -> b a s t", s=n_samples_biased
        )

        log_dict["fde/prior_biased"] = FDE(
            unnormalizer(y_biased_samples, offset),
            unnormalizer(y, offset).unsqueeze(2),
            mask_loss=mask_loss_biased_samples,
        ).item()

        # Computing biased cost
        biased_cost = get_cost(
            self._cost,
            x,
            y_biased_samples,
            offset,
            x_ego,
            y_ego,
            dt,
            unnormalizer,
            mask_loss_biased_samples,
        )
        mask_cost_samples = mask_z_samples.squeeze(-1)
        mean_biased_cost = (biased_cost * mask_cost_samples.float() * weights).sum(
            2
        ) / ((mask_cost_samples.float() * weights).sum(2).clamp_min(1))
        log_dict["cost/mean_biased"] = (
            (mean_biased_cost * mask_agent).sum() / (mask_agent.sum())
        ).item()

        log_dict["cost/risk"] = (
            (risk_cost * mask_agent).sum() / (mask_agent.sum())
        ).item()

        # Computing loss between risk and biased cost
        risk_loss = risk_loss_function(
            mean_biased_cost,
            risk_cost.detach(),
            mask_loss.any(-1),
            self.risk_assymetry_factor,
        )
        log_dict["loss/risk"] = risk_loss.item()

        # Computing KL loss between prior and biased latent
        kl_loss = inference_distribution.kl_loss(
            biased_distribution,
            kl_threshold,
            mask_z=mask_z,
        )
        log_dict["loss/kl"] = kl_loss.item()

        loss = risk_weight * risk_loss + kl_weight * kl_loss
        log_dict["loss/total"] = loss.item()

        log_dict["loss/risk_weight"] = risk_weight

        log_dict.update(inference_distribution.log_dict("inference"))
        log_dict.update(biased_distribution.log_dict("biased"))

        return loss, log_dict

    def _predict_unnormalized(
        self,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        map: torch.Tensor,
        mask_map: torch.Tensor,
        offset: torch.Tensor,
        x_ego: torch.Tensor,
        y_ego: torch.Tensor,
        unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        risk_level: Optional[torch.Tensor],
        n_samples: int,
    ) -> torch.Tensor:
        """Run forward over sub-batches and return the unnormalized predictions
        concatenated along the batch dimension.

        Sub-batches hold n_samples elements so that the number of samples does not
        impact the memory usage. When n_samples is not positive, the whole batch is
        processed at once.
        """
        batch_size = x.shape[0]
        chunk_size = n_samples if n_samples > 0 else max(batch_size, 1)
        predictions = []
        # max(batch_size, 1) keeps a single (empty) forward pass for empty batches.
        for beg in range(0, max(batch_size, 1), chunk_size):
            end = beg + chunk_size
            y_chunk = self.forward(
                x=x[beg:end],
                mask_x=mask_x[beg:end],
                map=map[beg:end],
                mask_map=mask_map[beg:end],
                offset=offset[beg:end],
                x_ego=x_ego[beg:end],
                y_ego=y_ego[beg:end],
                risk_level=None if risk_level is None else risk_level[beg:end],
                n_samples=n_samples,
            )[0]
            predictions.append(unnormalizer(y_chunk, offset[beg:end]))
        return torch.cat(predictions, 0)

    def get_prediction_accuracy(
        self,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        map: torch.Tensor,
        mask_map: torch.Tensor,
        y: torch.Tensor,
        mask_loss: torch.Tensor,
        x_ego: torch.Tensor,
        y_ego: torch.Tensor,
        offset: torch.Tensor,
        unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        risk_level: torch.Tensor,
        num_samples_min_fde: int = 0,
    ) -> dict:
        """
        A function that calls the forward method and returns a dict that contains
        prediction metrics, which measure accuracy with respect to ground-truth
        future trajectory y.

        Args:
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            map: (batch_size, num_objects, object_sequence_length, map_feature_dim)
                tensor of encoded map objects
            mask_map: (batch_size, num_objects, object_sequence_length) tensor True
                where map features are good False where it is padding
            y: (batch_size, num_agents, num_steps_future, state_dim) tensor of
                future trajectory.
            mask_loss: (batch_size, num_agents, num_steps_future) tensor of bool
                mask set to True where the loss should be computed and to False
                where it shouldn't
            x_ego: (batch_size, 1, num_steps, state_dim) tensor of ego history
            y_ego: (batch_size, 1, num_steps_future, state_dim) tensor of ego
                future trajectory
            offset: (batch_size, num_agents, state_dim) offset position from ego
            unnormalizer: function that takes in a trajectory and an offset and
                that outputs the unnormalized trajectory
            risk_level: (batch_size, num_agents) tensor of risk levels desired for
                future trajectories; biased metrics are skipped when it is None
            num_samples_min_fde: number of samples to use when computing the
                minimum final displacement error. With 0 (the default) a single
                prediction is made and only the FDE is logged.

        Returns:
            dict: dict that contains prediction-related metrics to be logged
        """
        log_dict = dict()
        with torch.no_grad():
            # Unbiased predictions are drawn first, then the biased ones.
            y_predict = self._predict_unnormalized(
                x,
                mask_x,
                map,
                mask_map,
                offset,
                x_ego,
                y_ego,
                unnormalizer,
                None,
                num_samples_min_fde,
            )
            # Predictions are unnormalized, so the ground truth must be as well.
            y_unnormalized = unnormalizer(y, offset)

            use_samples = num_samples_min_fde > 0
            if use_samples:
                # Add a sample dimension to the ground truth and to its mask.
                y_target = y_unnormalized.unsqueeze(-3)
                mask_target = repeat(
                    mask_loss, "b a t -> b a samples t", samples=num_samples_min_fde
                )
            else:
                y_target = y_unnormalized
                mask_target = mask_loss

            if risk_level is not None:
                y_predict_biased = self._predict_unnormalized(
                    x,
                    mask_x,
                    map,
                    mask_map,
                    offset,
                    x_ego,
                    y_ego,
                    unnormalizer,
                    risk_level,
                    num_samples_min_fde,
                )
                log_dict["fde/prior_biased"] = FDE(
                    y_predict_biased, y_target, mask_loss=mask_target
                ).item()
                if use_samples:
                    log_dict["minfde/prior_biased"] = minFDE(
                        y_predict_biased, y_target, mask_loss=mask_target
                    ).item()

            log_dict["fde/prior"] = FDE(
                y_predict, y_target, mask_loss=mask_target
            ).item()
            if use_samples:
                log_dict["minfde/prior"] = minFDE(
                    y_predict, y_target, mask_loss=mask_target
                ).item()
        return log_dict


def cvae_factory(
    params: CVAEParams,
    cost_function: BaseCostTorch,
    risk_estimator: AbstractMonteCarloRiskEstimator,
    training_mode: str = "cvae",
):
    """Biased CVAE with a biased MLP encoder and an MLP decoder.

    Args:
        params: dataclass defining the necessary parameters; its
            latent_distribution field must be "gaussian" or "quantized"
        cost_function: cost function used to compute the risk objective
        risk_estimator: risk estimator used to compute the risk objective
        training_mode: "inference", "cvae" or "bias" set what is the training mode

    Returns:
        An InferenceBiasedCVAE in eval mode when training_mode is "inference",
        otherwise a TrainingBiasedCVAE.

    Raises:
        ValueError: if params.latent_distribution is not a supported value.
    """
    absolute_encoder_nn = MLP(
        params.dynamic_state_dim,
        params.hidden_dim,
        params.hidden_dim,
        params.num_hidden_layers,
        params.is_mlp_residual,
    )
    map_encoder_nn = MapEncoderNN(params)

    if params.latent_distribution == "gaussian":
        latent_distribution_creator = GaussianLatentDistribution
        # The encoders output 2 * latent_dim values for a Gaussian latent.
        prior_distribution = GaussianLatentDistribution(
            torch.zeros(1, 1, 2 * params.latent_dim)
        )
        future_encoder_latent_dim = 2 * params.latent_dim
        inference_encoder_latent_dim = 2 * params.latent_dim
        biased_encoder_latent_dim = 2 * params.latent_dim
    elif params.latent_distribution == "quantized":
        latent_distribution_creator = QuantizedDistributionCreator(
            params.latent_dim, params.num_vq
        )
        prior_distribution = latent_distribution_creator(
            torch.zeros(1, 1, params.num_vq)
        )
        future_encoder_latent_dim = params.latent_dim
        inference_encoder_latent_dim = params.num_vq
        biased_encoder_latent_dim = params.num_vq
    else:
        raise ValueError(
            "Unknown latent distribution: "
            f"{params.latent_distribution!r} (expected 'gaussian' or 'quantized')"
        )

    biased_encoder_nn = BiasedEncoderNN(
        params,
        biased_encoder_latent_dim,
        num_steps=params.num_steps,
    )
    biased_encoder = CVAEEncoder(
        biased_encoder_nn, latent_distribution_creator=latent_distribution_creator
    )

    future_encoder_nn = FutureEncoderNN(
        params, future_encoder_latent_dim, params.num_steps + params.num_steps_future
    )
    future_encoder = CVAEEncoder(
        future_encoder_nn, latent_distribution_creator=latent_distribution_creator
    )

    inference_encoder_nn = InferenceEncoderNN(
        params, inference_encoder_latent_dim, params.num_steps
    )
    inference_encoder = CVAEEncoder(
        inference_encoder_nn, latent_distribution_creator=latent_distribution_creator
    )

    decoder_nn = DecoderNN(params)
    decoder = CVAEAccelerationDecoder(decoder_nn)

    if training_mode == "inference":
        cvae = InferenceBiasedCVAE(
            absolute_encoder_nn,
            map_encoder_nn,
            biased_encoder,
            inference_encoder,
            decoder,
            prior_distribution=prior_distribution,
        )
        cvae.eval()
        return cvae

    return TrainingBiasedCVAE(
        absolute_encoder_nn,
        map_encoder_nn,
        biased_encoder,
        inference_encoder,
        decoder,
        future_encoder=future_encoder,
        cost_function=cost_function,
        risk_estimator=risk_estimator,
        training_mode=training_mode,
        latent_regularization=params.latent_regularization,
        risk_assymetry_factor=params.risk_assymetry_factor,
        prior_distribution=prior_distribution,
    )