File size: 4,249 Bytes
9d0d223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Literal

import torch
import torch.nn as nn


class WMDetectionLoss(nn.Module):
    """Compute the detection loss.

    Detector outputs are laid out as ``[bsz, 2 + nbits, time_steps]``. Only the
    first two channels are used here. They hold per-time-step class
    probabilities, scored with ``NLLLoss``:

    - class 1 marks a watermarked step;
    - class 0 marks a non-watermarked step.

    Args:
        p_weight: scale applied to the loss on the positive (watermarked) input.
        n_weight: scale applied to the loss on the negative (clean) input.
    """

    def __init__(self, p_weight: float = 1.0, n_weight: float = 1.0) -> None:
        super().__init__()
        self.criterion = nn.NLLLoss()
        self.p_weight = p_weight
        self.n_weight = n_weight

    def forward(self, positive, negative, mask, message=None):
        """Return the weighted detection loss.

        Args:
            positive: outputs on watermarked samples [bsz, 2+nbits, time_steps].
            negative: outputs on not watermarked samples [bsz, 2+nbits, time_steps].
            mask: watermark mask [bsz, 1, time_steps].
            message: unused; the signature mirrors ``WMMbLoss.forward``.

        Returns:
            If ``mask`` is all ones:
            ``p_weight * NLL(positive, 1) + n_weight * NLL(negative, 0)``.
            Otherwise: ``p_weight * NLL(positive, mask)``, where the
            positive targets come from the integer-cast mask.
        """
        # Keep the two detection channels only. The network outputs softmax
        # probabilities, but NLLLoss expects log-probabilities.
        pos_log_probs = torch.log(positive[:, :2, :])
        neg_log_probs = torch.log(negative[:, :2, :])

        # Integer class targets shaped [bsz, time_steps], on the inputs' device.
        template = positive[:, 0, :]
        pos_targets = torch.ones_like(template, dtype=int)

        if torch.all(mask == 1):
            # Fully watermarked positive input: class 1 everywhere on the
            # positive input, class 0 everywhere on the negative input.
            neg_targets = torch.zeros_like(template, dtype=int)
            pos_term = self.p_weight * self.criterion(pos_log_probs, pos_targets)
            neg_term = self.n_weight * self.criterion(neg_log_probs, neg_targets)
            return pos_term + neg_term

        # Partially watermarked input. Multiplying by the integer-cast mask
        # turns target 1 into 0 where the mask is 0, so the positive input
        # already carries both classes and no negative term is needed.
        pos_targets = pos_targets * mask[:, 0, :].to(int)
        return self.p_weight * self.criterion(pos_log_probs, pos_targets)


class WMMbLoss(nn.Module):
    def __init__(self, temperature: float, loss_type: Literal["bce", "mse"]) -> None:
        """
        Compute the masked sample-level decoding loss
        (https://arxiv.org/pdf/2401.17264)

        Args:
            temperature: temperature for loss computation; the decoder logits
                are divided by it before the loss is applied
            loss_type: bce or mse between outputs and original message
        """
        super().__init__()
        # Same as Softmax + NLLLoss, but for a single output unit per bit.
        self.bce_with_logits = nn.BCEWithLogitsLoss()
        self.mse = nn.MSELoss()
        self.loss_type = loss_type
        self.temperature = temperature

    def forward(self, positive, negative, mask, message):
        """
        Compute decoding loss over the watermarked time steps only.

        Args:
            positive: outputs on watermarked samples [bsz, 2+nbits, time_steps]
            negative: outputs on not watermarked samples [bsz, 2+nbits, time_steps]
                (currently unused)
            mask: watermark mask [bsz, 1, time_steps]; steps where mask == 1
                contribute to the loss
            message: original message [bsz, nbits] or None

        Returns:
            Mean BCE-with-logits (or MSE) between ``bits / temperature`` and the
            message bit. The mean is taken over every (sample, bit, step) with
            mask == 1. Samples may have different numbers of watermarked steps.
            A zero scalar (on ``positive``'s device and dtype) is returned when
            any of these holds:
            - ``message`` is None;
            - ``message`` has an empty batch;
            - no step is watermarked.

        Raises:
            ValueError: if ``loss_type`` is neither "bce" nor "mse".
        """
        if message is None or message.size(0) == 0:
            return positive.new_zeros(())
        positive = positive[:, 2:, :]  # b 2+nbits t -> b nbits t
        assert (
            positive.shape[-2] == message.shape[1]
        ), "in decoding loss: \
            enc and dec don't share nbits, are you using multi-bit?"

        # Select the same watermarked positions from the logits and from the
        # message broadcast over time. The mask [b, 1, t] broadcasts over nbits.
        # The two 1-D results stay element-aligned even when samples have
        # different numbers of watermarked steps; a [b, nbits, -1] reshape
        # would misalign or fail in that case.
        selected = mask == 1
        logits = torch.masked_select(positive, selected)
        target = torch.masked_select(message.unsqueeze(-1).expand_as(positive), selected)
        if logits.numel() == 0:
            # Nothing watermarked: no bits to decode.
            return positive.new_zeros(())

        # In this case the temperature acts like a softmax temperature.
        logits = logits / self.temperature
        target = target.to(logits.dtype)
        if self.loss_type == "bce":
            return self.bce_with_logits(logits, target)
        if self.loss_type == "mse":
            return self.mse(logits, target)
        raise ValueError(
            f"Unsupported loss_type {self.loss_type!r}; expected 'bce' or 'mse'"
        )