Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Literal | |
import torch | |
import torch.nn as nn | |
class WMDetectionLoss(nn.Module):
    """Sample-level watermark detection loss built on ``nn.NLLLoss``.

    Only the first two channels of each detector output are used. They are
    treated as per-time-step class probabilities, with class 1 meaning
    "watermarked" and class 0 meaning "not watermarked".
    """

    def __init__(self, p_weight: float = 1.0, n_weight: float = 1.0) -> None:
        """Store the multipliers for the positive and negative loss terms."""
        super().__init__()
        self.criterion = nn.NLLLoss()
        self.p_weight = p_weight
        self.n_weight = n_weight

    def forward(self, positive, negative, mask, message=None):
        """Return the weighted detection loss.

        Args:
            positive: detector outputs on watermarked samples,
                [bsz, 2+nbits, time_steps]; channels 0-1 are used.
            negative: detector outputs on non-watermarked samples,
                same layout as ``positive``.
            mask: watermark mask, [bsz, 1, time_steps]; a value of 1 marks a
                time step that carries the watermark.
            message: unused; accepted so the signature matches the other
                loss in this file.

        Returns:
            ``p_weight * NLL(positive) + n_weight * NLL(negative)`` when the
            mask is all ones, otherwise only the positive term, computed
            against mask-derived targets.
        """
        # b 2+nbits t -> b 2 t
        pos_probs = positive[:, :2, :]
        neg_probs = negative[:, :2, :]

        # A [bsz, time_steps] view used only as a shape/device template.
        template = pos_probs[:, 0, :]
        target_pos = torch.ones_like(template, dtype=torch.long)
        target_neg = torch.zeros_like(template, dtype=torch.long)

        # The network emits softmax probabilities while NLLLoss expects
        # log-probabilities, hence the explicit log.
        log_pos = pos_probs.log()
        log_neg = neg_probs.log()

        if torch.all(mask == 1):
            pos_term = self.p_weight * self.criterion(log_pos, target_pos)
            neg_term = self.n_weight * self.criterion(log_neg, target_neg)
            return pos_term + neg_term

        # Part of the watermark is masked out: multiplying by the mask flips
        # the target from 1 to 0 at those time steps. The masked regions
        # already act as negatives, so no separate negative term is added.
        target_pos = target_pos * mask[:, 0, :].to(torch.long)
        return self.p_weight * self.criterion(log_pos, target_pos)
class WMMbLoss(nn.Module):
    """Masked sample-level message decoding loss.

    Reference: https://arxiv.org/pdf/2401.17264
    """

    def __init__(self, temperature: float, loss_type: Literal["bce", "mse"]) -> None:
        """
        Args:
            temperature: divisor applied to the decoder outputs before the
                loss is computed
            loss_type: "bce" or "mse" between outputs and original message
        """
        super().__init__()
        # Same as Softmax + NLLLoss, but for a single output unit per bit.
        self.bce_with_logits = nn.BCEWithLogitsLoss()
        self.mse = nn.MSELoss()
        self.loss_type = loss_type
        self.temperature = temperature

    def forward(self, positive, negative, mask, message):
        """
        Compute decoding loss

        Args:
            positive: outputs on watermarked samples [bsz, 2+nbits, time_steps]
            negative: outputs on not watermarked samples
                [bsz, 2+nbits, time_steps]; currently unused
            mask: watermark mask [bsz, 1, time_steps]
            message: original message [bsz, nbits] or None

        Returns:
            Scalar loss tensor; a zero scalar when there is no message
            (``message`` is None or has an empty batch dimension).

        Raises:
            AssertionError: if the number of decoded bit channels differs
                from ``message.shape[1]``.
            ValueError: if ``self.loss_type`` is neither "bce" nor "mse".
        """
        # Nothing to decode: the docstring allows message=None, so guard it
        # explicitly instead of failing on ``None.size``. The zero is created
        # on the same device as the model outputs.
        if message is None or message.size(0) == 0:
            return torch.tensor(0.0, device=positive.device)

        positive = positive[:, 2:, :]  # b 2+nbits t -> b nbits t
        assert positive.shape[-2] == message.shape[1], (
            "in decoding loss: "
            "enc and dec don't share nbits, are you using multi-bit?"
        )

        # Cut the last dim of positive to keep only where mask is 1.
        # NOTE(review): the reshape presumably relies on every batch item
        # keeping the same number of time steps — confirm with the caller.
        new_shape = [*positive.shape[:-1], -1]  # b nbits -1
        positive = torch.masked_select(positive, mask == 1).reshape(new_shape)

        # Broadcast each message bit across the kept time steps.
        message = message.unsqueeze(-1).repeat(1, 1, positive.shape[2])  # b k -> b k t
        target = message.float()
        scaled = positive / self.temperature

        if self.loss_type == "bce":
            # in this case similar to temperature in softmax
            return self.bce_with_logits(scaled, target)
        if self.loss_type == "mse":
            return self.mse(scaled, target)
        # Previously an unknown loss_type surfaced as UnboundLocalError.
        raise ValueError(
            f"unsupported loss_type {self.loss_type!r}; expected 'bce' or 'mse'"
        )