Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import math | |
import torch | |
from torch.nn.modules.loss import _Loss | |
class LatentLayersKLLoss(_Loss): | |
def __init__(self, args): | |
super().__init__() | |
self.args = args | |
def forward(self, layer_samples, lang_idx, update_num, sample_size): | |
prior = self.args.prior | |
samples = layer_samples[lang_idx] | |
eps = 1e-7 | |
if prior == "uniform": | |
# uniform prior | |
kl_loss = (samples * (torch.log(samples + eps) - math.log(0.5))).sum(-1) | |
elif prior == "agged_posterior": | |
# aggregated posterior | |
y_t = torch.stack([x.detach() for x in layer_samples], dim=0) | |
agged_q = torch.sum(y_t, dim=0) | |
row_norm = agged_q.sum(-1) | |
normed_agg_q = agged_q / row_norm | |
kl_loss = ( | |
samples * (torch.log(samples + eps) - torch.log(normed_agg_q + eps)) | |
).sum(-1) | |
else: | |
raise NotImplementedError("The specified prior is not implemented.") | |
# normalized by number of layers | |
kl_loss /= layer_samples[0].size()[0] | |
kl_weight = min( | |
self.args.sparsity_weight, | |
(update_num - self.args.soft_update) | |
* self.args.sparsity_weight | |
/ self.args.anneal_updates, | |
) | |
kl_loss *= kl_weight * sample_size | |
return kl_loss | |
class LatentLayersSparsityLoss(_Loss): | |
def __init__(self, args): | |
super().__init__() | |
self.args = args | |
def is_valid(self, update_num): | |
if self.args.target_layers <= 0: | |
return False | |
return update_num > (self.args.soft_update + self.args.anneal_updates) | |
def forward(self, layer_samples_list, update_num, sample_size): | |
batch_loss = 0 | |
share_loss = 0 | |
global_sparsity_loss = 0 | |
layer_samples = torch.stack(layer_samples_list, dim=0) | |
if ( | |
self.args.target_layers > 0 or self.args.share_weight > 0 | |
) and update_num > (self.args.soft_update + self.args.anneal_updates): | |
# anneal sparsity weight | |
if update_num < (self.args.anneal_updates + self.args.soft_update): | |
weight_anneal = 0 | |
elif update_num < (2 * self.args.anneal_updates + self.args.soft_update): | |
weight_anneal = ( | |
(update_num - self.args.soft_update - self.args.anneal_updates) | |
* self.args.share_weight | |
/ self.args.anneal_updates | |
) | |
else: | |
weight_anneal = 1 | |
# compute ratio among languages | |
layer_utilization = torch.sum(layer_samples, dim=0) | |
layer_utilization /= layer_samples.size()[0] | |
if self.args.share_weight > 0: | |
# encouraging sharing across languages | |
share_loss = sum( | |
-1.0 * v * math.log(v) for v in layer_utilization if v > 0 | |
) | |
batch_loss += ( | |
weight_anneal * self.args.share_weight * sample_size * share_loss | |
) | |
if self.args.target_layers > 0: | |
# computed expected number of layers selected | |
expeted_layers = sum(layer_utilization) | |
# compute l2 loss wrt target number of layers | |
global_sparsity_loss = (expeted_layers - self.args.target_layers) ** 2 | |
batch_loss += ( | |
weight_anneal | |
* self.args.share_weight | |
* sample_size | |
* global_sparsity_loss | |
) | |
return batch_loss | |