# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from audiocraft.models.loaders import load_audioseal_models
class WMModel(ABC, nn.Module):
"""
    A wrapper interface around different watermarking models for
    training or evaluation purposes.
"""
@abstractmethod
def get_watermark(
self,
x: torch.Tensor,
message: tp.Optional[torch.Tensor] = None,
sample_rate: int = 16_000,
) -> torch.Tensor:
"""Get the watermark from an audio tensor and a message.
        If the input message is None, a random message of
        n bits {0, 1} will be generated.
        """
@abstractmethod
def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
"""Detect the watermarks from the audio signal
Args:
x: Audio signal, size batch x frames
Returns:
tensor of size (B, 2+n, frames) where:
Detection results of shape (B, 2, frames)
Message decoding results of shape (B, n, frames)
"""
class AudioSeal(WMModel):
"""Wrap Audioseal (https://github.com/facebookresearch/audioseal) for the
training and evaluation. The generator and detector are jointly trained
"""
def __init__(
self,
generator: nn.Module,
detector: nn.Module,
nbits: int = 0,
):
super().__init__()
self.generator = generator # type: ignore
self.detector = detector # type: ignore
        # Allow re-training an n-bit model with a new 0-bit message
self.nbits = nbits if nbits else self.generator.msg_processor.nbits
def get_watermark(
self,
x: torch.Tensor,
message: tp.Optional[torch.Tensor] = None,
sample_rate: int = 16_000,
) -> torch.Tensor:
return self.generator.get_watermark(x, message=message, sample_rate=sample_rate)
def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
"""
Detect the watermarks from the audio signal. The first two units of the output
are used for detection, the rest is used to decode the message. If the audio is
not watermarked, the message will be random.
Args:
x: Audio signal, size batch x frames
        Returns:
            torch.Tensor: Detection + decoding results of shape (B, 2 + nbits, T).
"""
        # Get the raw detector output: (B, 2 + nbits, frames)
        result = self.detector.detector(x)
        # Apply softmax to the first two units, which are used for detection
result[:, :2, :] = torch.softmax(result[:, :2, :], dim=1)
return result
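
    # Interpretation note (added here, not part of the upstream API docs): after the
    # softmax above, result[:, :2, :] holds per-frame detection probabilities and
    # result[:, 2:, :] holds per-frame soft message bits that can be thresholded at
    # 0.5 (optionally after averaging over frames) to recover the n-bit message.
    # Treating index 1 as the "watermarked" class is an assumption.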
def forward( # generator
self,
x: torch.Tensor,
message: tp.Optional[torch.Tensor] = None,
sample_rate: int = 16_000,
alpha: float = 1.0,
) -> torch.Tensor:
"""Apply the watermarking to the audio signal x with a tune-down ratio (default 1.0)"""
        wm = self.get_watermark(x, message, sample_rate=sample_rate)
return x + alpha * wm
    @staticmethod
    def get_pretrained(name="base", device=None) -> WMModel:
        """Load a pretrained AudioSeal model, defaulting to CUDA when available."""
        if device is None:
            device = "cuda" if torch.cuda.device_count() else "cpu"
        return load_audioseal_models("facebook/audioseal", filename=name, device=device)
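

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; assumes the "facebook/audioseal"
    # checkpoint can be fetched by load_audioseal_models and that the model accepts
    # mono audio of shape (batch, channels, frames) at 16 kHz).
    model = AudioSeal.get_pretrained(name="base", device="cpu")
    audio = torch.randn(1, 1, 16_000)  # one second of dummy mono audio
    watermarked = model(audio, sample_rate=16_000, alpha=1.0)
    result = model.detect_watermark(watermarked)
    # Average the per-frame detection probability; index 1 as the positive
    # ("watermarked") class is an assumption, see the note in detect_watermark.
    detection_prob = result[:, 1, :].mean().item()
    print(f"Mean watermark probability: {detection_prob:.3f}")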