my commited on
Commit
32ca76b
1 Parent(s): 84c0c04

Add application file

Browse files
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch==1.13.1
2
+ torchvision==0.14.1
3
+ torchaudio==0.13.1
.gitignore ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
+
3
+ #idea
4
+ .idea
5
+ wandb/
6
+ temp/
7
+ data/
8
+
9
+ # Byte-compiled / optimized / DLL files
10
+ __pycache__/
11
+ *.py[cod]
12
+ *$py.class
13
+
14
+ # C extensions
15
+ *.so
16
+
17
+ # Distribution / packaging
18
+ .Python
19
+ build/
20
+ develop-eggs/
21
+ dist/
22
+ downloads/
23
+ eggs/
24
+ .eggs/
25
+ lib/
26
+ lib64/
27
+ parts/
28
+ sdist/
29
+ var/
30
+ wheels/
31
+ share/python-wheels/
32
+ *.egg-info/
33
+ .installed.cfg
34
+ *.egg
35
+ MANIFEST
36
+
37
+ # PyInstaller
38
+ # Usually these files are written by a python script from a template
39
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
40
+ *.manifest
41
+ *.spec
42
+
43
+ # Installer logs
44
+ pip-log.txt
45
+ pip-delete-this-directory.txt
46
+
47
+ # Unit test / coverage reports
48
+ htmlcov/
49
+ .tox/
50
+ .nox/
51
+ .coverage
52
+ .coverage.*
53
+ .cache
54
+ nosetests.xml
55
+ coverage.xml
56
+ *.cover
57
+ *.py,cover
58
+ .hypothesis/
59
+ .pytest_cache/
60
+ cover/
61
+
62
+ # Translations
63
+ *.mo
64
+ *.pot
65
+
66
+ # Django stuff:
67
+ *.log
68
+ local_settings.py
69
+ db.sqlite3
70
+ db.sqlite3-journal
71
+
72
+ # Flask stuff:
73
+ instance/
74
+ .webassets-cache
75
+
76
+ # Scrapy stuff:
77
+ .scrapy
78
+
79
+ # Sphinx documentation
80
+ docs/_build/
81
+
82
+ # PyBuilder
83
+ .pybuilder/
84
+ target/
85
+
86
+ # Jupyter Notebook
87
+ .ipynb_checkpoints
88
+
89
+ # IPython
90
+ profile_default/
91
+ ipython_config.py
92
+
93
+ # pyenv
94
+ # For a library or package, you might want to ignore these files since the code is
95
+ # intended to run in multiple environments; otherwise, check them in:
96
+ # .python-version
97
+
98
+ # pipenv
99
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
101
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
102
+ # install all needed dependencies.
103
+ #Pipfile.lock
104
+
105
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106
+ __pypackages__/
107
+
108
+ # Celery stuff
109
+ celerybeat-schedule
110
+ celerybeat.pid
111
+
112
+ # SageMath parsed files
113
+ *.sage.py
114
+
115
+ # Environments
116
+ .env
117
+ .venv
118
+ env/
119
+ venv/
120
+ ENV/
121
+ env.bak/
122
+ venv.bak/
123
+
124
+ # Spyder project settings
125
+ .spyderproject
126
+ .spyproject
127
+
128
+ # Rope project settings
129
+ .ropeproject
130
+
131
+ # mkdocs documentation
132
+ /site
133
+
134
+ # mypy
135
+ .mypy_cache/
136
+ .dmypy.json
137
+ dmypy.json
138
+
139
+ # Pyre type checker
140
+ .pyre/
141
+
142
+ # pytype static type analyzer
143
+ .pytype/
144
+
145
+ # Cython debug symbols
146
+ cython_debug/
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import time
3
+
4
+ import soundfile
5
+ import streamlit as st
6
+ import os
7
+ from utils import wm_add_v2, file_reader, model_util, wm_decode_v2, bin_util
8
+ from models import my_model_v7_recover
9
+ import torch
10
+ import uuid
11
+ import datetime
12
+ import numpy as np
13
+ from huggingface_hub import hf_hub_download, HfApi
14
+
15
+
16
+ # Function to add watermark to audio
17
+ def add_watermark(audio_path, watermark_text):
18
+ assert len(watermark_text) == 5
19
+
20
+ start_bit, msg_bit, watermark = wm_add_v2.create_parcel_message(len_start_bit, 32, watermark_text)
21
+
22
+ data, sr, audio_length_second = file_reader.read_as_single_channel_16k(audio_path, 16000)
23
+
24
+ _, signal_wmd, time_cost = wm_add_v2.add_watermark(watermark, data, 16000, 0.1, device, model)
25
+
26
+ tmp_file_name = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + "_" + str(uuid.uuid4()) + ".wav"
27
+ tmp_file_path = 'temp/' + tmp_file_name
28
+ soundfile.write(tmp_file_path, signal_wmd, sr)
29
+ return tmp_file_path
30
+
31
+
32
+ # Function to decode watermark from audio
33
+ def decode_watermark(audio_path):
34
+ data, sr, audio_length_second = file_reader.read_as_single_channel_16k(audio_path, 16000)
35
+ data = data[0:5 * sr]
36
+ start_bit = wm_add_v2.fix_pattern[0:len_start_bit]
37
+ support_count, mean_result, results = wm_decode_v2.extract_watermark_v2(
38
+ data,
39
+ start_bit,
40
+ 0.1,
41
+ 16000,
42
+ 0.3,
43
+ model,
44
+ device, "best")
45
+
46
+ if mean_result is None:
47
+ return "No Watermark"
48
+
49
+ payload = mean_result[len_start_bit:]
50
+ return bin_util.binArray2HexStr(payload)
51
+
52
+
53
+ # Main web app
54
+ def main():
55
+ if "def_value" not in st.session_state:
56
+ st.session_state.def_value = bin_util.binArray2HexStr(np.random.choice([0, 1], size=32 - len_start_bit))
57
+
58
+ st.title("Neural Audio Watermark")
59
+ st.write("Choose the action you want to perform:")
60
+
61
+ action = st.selectbox("Select Action", ["Add Watermark", "Decode Watermark"])
62
+
63
+ if action == "Add Watermark":
64
+ audio_file = st.file_uploader("Upload Audio File (WAV)", type=["wav"], accept_multiple_files=False)
65
+ if audio_file:
66
+ tmp_input_audio_file = os.path.join("temp", audio_file.name)
67
+ with open(tmp_input_audio_file, "wb") as f:
68
+ f.write(audio_file.getbuffer())
69
+ st.audio(tmp_input_audio_file, format="audio/wav")
70
+
71
+ watermark_text = st.text_input("Enter Watermark Text (5 English letters)", value=st.session_state.def_value)
72
+
73
+ add_watermark_button = st.button("Add Watermark", key="add_watermark_btn")
74
+ if add_watermark_button: # 点击按钮后执行的
75
+ if audio_file and watermark_text:
76
+ with st.spinner("Adding Watermark..."):
77
+ # add_watermark_button.empty()
78
+ # st.button("Add Watermark", disabled=True)
79
+ # st.button("Add Watermark", disabled=True, key="add_watermark_btn_disabled")
80
+ t1 = time.time()
81
+
82
+ watermarked_audio = add_watermark(tmp_input_audio_file, watermark_text)
83
+ encode_time_cost = time.time() - t1
84
+
85
+ st.write("Watermarked Audio:")
86
+ st.audio(watermarked_audio, format="audio/wav")
87
+ st.write("Time Cost:%d seconds" % encode_time_cost)
88
+
89
+ # st.button("Add Watermark", disabled=False)
90
+
91
+ elif action == "Decode Watermark":
92
+ audio_file = st.file_uploader("Upload Audio File (WAV/MP3)", type=["wav", "mp3"], accept_multiple_files=False)
93
+ if audio_file:
94
+ if st.button("Decode Watermark"):
95
+ # 1.保存
96
+ tmp_file_for_decode_path = os.path.join("temp", audio_file.name)
97
+ with open(tmp_file_for_decode_path, "wb") as f:
98
+ f.write(audio_file.getbuffer())
99
+
100
+ # 2.执行
101
+ with st.spinner("Decoding..."):
102
+ t1 = time.time()
103
+ decoded_watermark = decode_watermark(tmp_file_for_decode_path)
104
+ decode_cost = time.time() - t1
105
+
106
+ print("decoded_watermark", decoded_watermark)
107
+ # Display the decoded watermark
108
+ st.write("Decoded Watermark:", decoded_watermark)
109
+ st.write("Time Cost:%d seconds" % (decode_cost))
110
+
111
+
112
+ def load_model(resume_path):
113
+ n_fft = 1000
114
+ hop_length = 400
115
+ # https://huggingface.co/M4869/InvertibleWM/blob/main/step59000_snr39.99_pesq4.35_BERP_none0.30_mean1.81_std1.81.pkl
116
+ api_key = st.secrets["api_key"]
117
+ print(api_key, api_key)
118
+ model_ckpt_path = hf_hub_download(repo_id="M4869/InvertibleWM",
119
+ filename="step59000_snr39.99_pesq4.35_BERP_none0.30_mean1.81_std1.81.pkl",
120
+ token=api_key
121
+ )
122
+ # print("model_ckpt_path", model_ckpt_path)
123
+ resume_path = model_ckpt_path
124
+ # return
125
+
126
+ model = my_model_v7_recover.Model(16000, 32, n_fft, hop_length,
127
+ use_recover_layer=False, num_layers=8).to(device)
128
+ checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
129
+ state_dict = model_util.map_state_dict(checkpoint['model'])
130
+ model.load_state_dict(state_dict, strict=True)
131
+ model.eval()
132
+ return model
133
+
134
+
135
+ if __name__ == "__main__":
136
+ len_start_bit = 12
137
+
138
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
139
+
140
+ model = load_model("./data/step59000_snr39.99_pesq4.35_BERP_none0.30_mean1.81_std1.81.pkl")
141
+
142
+ main()
143
+ # decode_watermark("/Users/my/Downloads/7a95b353a46893903e9f946c24170b210ce14e8c52c63bb2ab3d144e.wav")
models/__init__.py ADDED
File without changes
models/hinet.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from models.invblock import INV_block
3
+
4
+
5
+ class Hinet(torch.nn.Module):
6
+
7
+ def __init__(self, in_channel=2, num_layers=16):
8
+ super(Hinet, self).__init__()
9
+ self.inv_blocks = torch.nn.ModuleList([INV_block(in_channel) for _ in range(num_layers)])
10
+
11
+ def forward(self, x1, x2, rev=False):
12
+ # x1:cover
13
+ # x2:secret
14
+ if not rev:
15
+ for inv_block in self.inv_blocks:
16
+ x1, x2 = inv_block(x1, x2)
17
+ else:
18
+ for inv_block in reversed(self.inv_blocks):
19
+ x1, x2 = inv_block(x1, x2, rev=True)
20
+ return x1, x2
models/invblock.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from models.rrdb_denselayer import ResidualDenseBlock_out
4
+
5
+
6
+ class INV_block(nn.Module):
7
+ def __init__(self, channel=2, subnet_constructor=ResidualDenseBlock_out, clamp=2.0):
8
+ super().__init__()
9
+ self.clamp = clamp
10
+
11
+ # ρ
12
+ self.r = subnet_constructor(channel, channel)
13
+ # η
14
+ self.y = subnet_constructor(channel, channel)
15
+ # φ
16
+ self.f = subnet_constructor(channel, channel)
17
+
18
+ def e(self, s):
19
+ return torch.exp(self.clamp * 2 * (torch.sigmoid(s) - 0.5))
20
+
21
+ def forward(self, x1, x2, rev=False):
22
+ if not rev:
23
+
24
+ t2 = self.f(x2)
25
+ y1 = x1 + t2
26
+ s1, t1 = self.r(y1), self.y(y1)
27
+ y2 = self.e(s1) * x2 + t1
28
+
29
+ else:
30
+
31
+ s1, t1 = self.r(x1), self.y(x1)
32
+ y2 = (x2 - t1) / self.e(s1)
33
+ t2 = self.f(y2)
34
+ y1 = (x1 - t2)
35
+
36
+ return y1, y2
models/module_util.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.init as init
4
+ import torch.nn.functional as F
5
+
6
+
7
+ def initialize_weights(net_l, scale=1):
8
+ if not isinstance(net_l, list):
9
+ net_l = [net_l]
10
+ for net in net_l:
11
+ for m in net.modules():
12
+ if isinstance(m, nn.Conv2d):
13
+ init.kaiming_normal_(m.weight, a=0, mode='fan_in')
14
+ m.weight.data *= scale # for residual block
15
+ if m.bias is not None:
16
+ m.bias.data.zero_()
17
+ elif isinstance(m, nn.Linear):
18
+ init.kaiming_normal_(m.weight, a=0, mode='fan_in')
19
+ m.weight.data *= scale
20
+ if m.bias is not None:
21
+ m.bias.data.zero_()
22
+ elif isinstance(m, nn.BatchNorm2d):
23
+ init.constant_(m.weight, 1)
24
+ init.constant_(m.bias.data, 0.0)
25
+
26
+
27
+ def make_layer(block, n_layers):
28
+ layers = []
29
+ for _ in range(n_layers):
30
+ layers.append(block())
31
+ return nn.Sequential(*layers)
32
+
33
+
34
+ class ResidualBlock_noBN(nn.Module):
35
+ '''Residual block w/o BN
36
+ ---Conv-ReLU-Conv-+-
37
+ |________________|
38
+ '''
39
+
40
+ def __init__(self, nf=64):
41
+ super(ResidualBlock_noBN, self).__init__()
42
+ self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
43
+ self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
44
+
45
+ # initialization
46
+ initialize_weights([self.conv1, self.conv2], 0.1)
47
+
48
+ def forward(self, x):
49
+ identity = x
50
+ out = F.relu(self.conv1(x), inplace=True)
51
+ out = self.conv2(out)
52
+ return identity + out
53
+
54
+
55
+ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
56
+ """Warp an image or feature map with optical flow
57
+ Args:
58
+ x (Tensor): size (N, C, H, W)
59
+ flow (Tensor): size (N, H, W, 2), normal value
60
+ interp_mode (str): 'nearest' or 'bilinear'
61
+ padding_mode (str): 'zeros' or 'border' or 'reflection'
62
+ Returns:
63
+ Tensor: warped image or feature map
64
+ """
65
+ flow = flow.permute(0,2,3,1)
66
+ assert x.size()[-2:] == flow.size()[1:3]
67
+ B, C, H, W = x.size()
68
+ # mesh grid
69
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
70
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
71
+ grid.requires_grad = False
72
+ grid = grid.type_as(x)
73
+ vgrid = grid + flow
74
+ # scale grid to [-1,1]
75
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
76
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
77
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
78
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
79
+ return output
models/my_model_v7_recover.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+
3
+ import torch.optim
4
+ import torch.nn as nn
5
+ from models.hinet import Hinet
6
+ # from utils.attacks import attack_layer, mp3_attack_v2, butterworth_attack
7
+ import numpy as np
8
+ import random
9
+
10
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
11
+
12
+
13
+ class Model(nn.Module):
14
+ def __init__(self, num_point, num_bit, n_fft, hop_length, use_recover_layer, num_layers):
15
+ super(Model, self).__init__()
16
+ self.hinet = Hinet(num_layers=num_layers)
17
+ self.watermark_fc = torch.nn.Linear(num_bit, num_point)
18
+ self.watermark_fc_back = torch.nn.Linear(num_point, num_bit)
19
+ self.n_fft = n_fft
20
+ self.hop_length = hop_length
21
+ self.dropout1 = torch.nn.Dropout()
22
+ self.identity = torch.nn.Identity()
23
+ self.recover_layer = SameSizeConv2d(2, 2)
24
+ self.use_recover_layer = use_recover_layer
25
+
26
+ def stft(self, data):
27
+ window = torch.hann_window(self.n_fft).to(data.device)
28
+ tmp = torch.stft(data, n_fft=self.n_fft, hop_length=self.hop_length, window=window, return_complex=False)
29
+ # [1, 501, 41, 2]
30
+ return tmp
31
+
32
+ def istft(self, signal_wmd_fft):
33
+ window = torch.hann_window(self.n_fft).to(signal_wmd_fft.device)
34
+
35
+ # Changed in version 2.0: Real datatype inputs are no longer supported. Input must now have a complex datatype, as returned by stft(..., return_complex=True).
36
+
37
+ return torch.istft(signal_wmd_fft, n_fft=self.n_fft, hop_length=self.hop_length, window=window,
38
+ return_complex=False)
39
+
40
+ def encode(self, signal, message, need_fft=False):
41
+ # 1.信号执行fft
42
+ signal_fft = self.stft(signal)
43
+ # import pdb
44
+ # pdb.set_trace()
45
+ # (batch,freq_bins,time_frames,2)
46
+
47
+ # 2.Message执行fft
48
+ message_expand = self.watermark_fc(message)
49
+ message_fft = self.stft(message_expand)
50
+
51
+ # 3.encode
52
+ signal_wmd_fft, msg_remain = self.enc_dec(signal_fft, message_fft, rev=False)
53
+ # (batch,freq_bins,time_frames,2)
54
+ signal_wmd = self.istft(signal_wmd_fft)
55
+ if need_fft:
56
+ return signal_wmd, signal_fft, message_fft
57
+
58
+ return signal_wmd
59
+
60
+ def decode(self, signal):
61
+ signal_fft = self.stft(signal)
62
+ if self.use_recover_layer:
63
+ signal_fft = self.recover_layer(signal_fft)
64
+ watermark_fft = signal_fft
65
+ # watermark_fft = torch.randn(signal_fft.shape).cuda()
66
+ _, message_restored_fft = self.enc_dec(signal_fft, watermark_fft, rev=True)
67
+ message_restored_expanded = self.istft(message_restored_fft)
68
+ message_restored_float = self.watermark_fc_back(message_restored_expanded).clamp(-1, 1)
69
+ return message_restored_float
70
+
71
+ def enc_dec(self, signal, watermark, rev):
72
+ signal = signal.permute(0, 3, 2, 1)
73
+ # [4, 2, 41, 501]
74
+
75
+ watermark = watermark.permute(0, 3, 2, 1)
76
+
77
+ # pdb.set_trace()
78
+ signal2, watermark2 = self.hinet(signal, watermark, rev)
79
+ return signal2.permute(0, 3, 2, 1), watermark2.permute(0, 3, 2, 1)
80
+
81
+
82
+ class SameSizeConv2d(nn.Module):
83
+ def __init__(self, in_channels, out_channels):
84
+ super(SameSizeConv2d, self).__init__()
85
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
86
+
87
+ def forward(self, x):
88
+ # (batch,501,41,2]
89
+ x1 = x.permute(0, 3, 1, 2)
90
+ # (batch,2,501,41]
91
+ x2 = self.conv(x1)
92
+ # (batch,2,501,41]
93
+ x3 = x2.permute(0, 2, 3, 1)
94
+ # (batch,501,41,2]
95
+ return x3
models/rrdb_denselayer.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import models.module_util as mutil
4
+
5
+
6
+ # Dense connection
7
+ class ResidualDenseBlock_out(nn.Module):
8
+ def __init__(self, in_channel, out_channel, bias=True):
9
+ super(ResidualDenseBlock_out, self).__init__()
10
+ self.conv1 = nn.Conv2d(in_channel, 32, 3, 1, 1, bias=bias)
11
+ self.conv2 = nn.Conv2d(in_channel + 32, 32, 3, 1, 1, bias=bias)
12
+ self.conv3 = nn.Conv2d(in_channel + 2 * 32, 32, 3, 1, 1, bias=bias)
13
+ self.conv4 = nn.Conv2d(in_channel + 3 * 32, 32, 3, 1, 1, bias=bias)
14
+ self.conv5 = nn.Conv2d(in_channel + 4 * 32, out_channel, 3, 1, 1, bias=bias)
15
+ self.lrelu = nn.LeakyReLU(inplace=True)
16
+ # initialization
17
+ mutil.initialize_weights([self.conv5], 0.)
18
+
19
+ def forward(self, x):
20
+ x1 = self.lrelu(self.conv1(x))
21
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
22
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
23
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
24
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
25
+ return x5
utils/__init__.py ADDED
File without changes
utils/bin_util.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def hexChar2binStr(v):
5
+ assert len(v) == 1
6
+ # e => '1110'
7
+ return '{0:04b}'.format(int(v, 16))
8
+
9
+
10
+ def hexStr2BinStr(hex_str):
11
+ output = [hexChar2binStr(c) for c in hex_str]
12
+ # ['1110', '1100', ....]
13
+ return "".join(output)
14
+
15
+
16
+ def hexStr2BinArray(hex_str):
17
+ # 十六进制字符串==> 0,1g构成的数组
18
+ tmp = hexStr2BinStr(hex_str)
19
+ return np.array([int(i) for i in tmp])
20
+
21
+
22
+ def binStr2HexStr(binary_str):
23
+ return hex(int(binary_str, 2))[2:]
24
+
25
+
26
+ def binArray2HexStr(bin_array):
27
+ tmp = "".join(["%d" % i for i in bin_array])
28
+ return binStr2HexStr(tmp)
29
+
30
+
31
+ # 判断是否为合法的16进制字符串
32
+ def is_hex_str(s):
33
+ hex_chars = "0123456789abcdefABCDEF"
34
+ return all(c in hex_chars for c in s)
35
+
36
+
37
+
38
+
39
+ def flip_bytearray(input_bytearray, num_bits_to_flip):
40
+ tmp = bytearray_to_binary_list(input_bytearray)
41
+ tmp = flip_array(tmp,num_bits_to_flip)
42
+ return binary_list_to_bytearray(tmp)
43
+
44
+ def flip_array(input_bits, num_bits_to_flip):
45
+
46
+ # 随机选择要翻转的位的索引
47
+ flip_indices = np.random.choice(len(input_bits), num_bits_to_flip, replace=False)
48
+
49
+ # 创建一个全零的掩码数组
50
+ mask = np.zeros_like(input_bits)
51
+
52
+ # 将选定的索引设置为 1
53
+ mask[flip_indices] = 1
54
+
55
+ # 将输入位数组与掩码进行逐元素异或运算,实现翻转位
56
+ flipped_bits = input_bits ^ mask
57
+ return flipped_bits
58
+
59
+
60
+
61
+ def bytearray_to_binary_list(byte_array):
62
+ binary_list = []
63
+ for byte in byte_array:
64
+ binary_str = format(byte, '08b') # 将字节转换为 8 位二进制字符串
65
+ binary_digits = [int(bit) for bit in binary_str] # 将二进制字符串转换为整数列表
66
+ binary_list.extend(binary_digits) # 将整数列表添加到结果列表中
67
+ return binary_list
68
+
69
+
70
+ def binary_list_to_bytearray(binary_list):
71
+ # 这个函数假设输入列表的长度是 8 的倍数,否则将引发异常。
72
+ byte_list = []
73
+ for i in range(0, len(binary_list), 8):
74
+ binary_str = ''.join(str(bit) for bit in binary_list[i:i + 8]) # 将 8 个位连接为一个二进制字符串
75
+ byte_value = int(binary_str, 2) # 将二进制字符串转换为整数
76
+ byte_list.append(byte_value) # 将整数添加到字节列表中
77
+ return bytearray(byte_list)
78
+
79
+
80
+
81
+
82
+ if __name__ == "__main__":
83
+ # hex_str = "ecd057f0d1fbb25d6430b338b5d72eb2"
84
+ # arr = hexStr2BinArray(hex_str)
85
+ # out = binArray2HexStr(arr)
86
+ # print(out==hex_str)
87
+ # bin_str = "".join()
88
+ # assert bin2hex_str(bin_str) == hex_str
89
+ # print(bin_str, len(bin_str))
90
+ #
91
+ watermark = np.random.randint(2, size=44)
92
+ res = binArray2HexStr(watermark)
93
+ print(res)
94
+
95
+ test_str1 = "3ad30c748a2"
96
+ test_str2 = "3ad30Z748a2"
97
+
98
+ print(is_hex_str(test_str1)) # 输出 True
99
+ print(is_hex_str(test_str2)) # 输出 False
100
+
101
+
102
+ # encode_file("1.wav", watermark)
103
+ # out = decode_file("tmp_output.wav")
104
+ # assert np.all(watermark == out)
utils/file_reader.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import soundfile
3
+ import librosa
4
+ import resampy
5
+
6
+
7
+ def is_wav_file(filename):
8
+ # 获取文件扩展名
9
+ file_extension = os.path.splitext(filename)[1]
10
+
11
+ # 判断文件扩展名是否为'.wav'或'.WAV'
12
+ return file_extension.lower() == ".wav"
13
+
14
+
15
+ import numpy as np
16
+
17
+
18
+ def read_as_single_channel_16k(audio_file, def_sr, verbose=False, aim_second=None):
19
+ assert os.path.exists(audio_file), "音频文件不存在"
20
+
21
+ file_extension = os.path.splitext(audio_file)[1].lower()
22
+
23
+ if file_extension == ".mp3":
24
+ data, origin_sr = librosa.load(audio_file, sr=None)
25
+ elif file_extension in [".wav", ".flac"]:
26
+ data, origin_sr = soundfile.read(audio_file)
27
+ else:
28
+ raise Exception("不支持的文件类型:" + file_extension)
29
+
30
+ # 通道数
31
+ if len(data.shape) == 2:
32
+ left_channel = data[:, 0]
33
+ if verbose:
34
+ print("双通道文件,变为单通道")
35
+ data = left_channel
36
+
37
+ # 采样率
38
+ if origin_sr != def_sr:
39
+ data = resampy.resample(data, origin_sr, def_sr)
40
+ if verbose:
41
+ print("原始音频采样率不是16kHZ,可能会对水印性能造成影响")
42
+
43
+ sr = def_sr
44
+ audio_length_second = 1.0 * len(data) / sr
45
+ if verbose:
46
+ print("输入音频长度:%d秒" % audio_length_second)
47
+
48
+ # 判断通道数
49
+ if len(data.shape) == 2:
50
+ data = data[:, 0]
51
+ print("选取第一个通道")
52
+
53
+ if aim_second is not None:
54
+ signal = data
55
+ assert len(signal) > 0
56
+ current_second = len(signal) / sr
57
+ if current_second < aim_second:
58
+ repeat_count = int(aim_second / current_second) + 1
59
+ signal = np.repeat(signal, repeat_count)
60
+ data = signal[0:sr * aim_second]
61
+
62
+ return data, sr, audio_length_second
63
+
64
+
65
+ def read_as_single_channel(file, aim_sr):
66
+ if file.endswith(".mp3"):
67
+ data, sr = librosa.load(file, sr=aim_sr) # 这里默认就是会转换为输入的sr
68
+ else:
69
+ data, sr = soundfile.read(file)
70
+
71
+ if len(data.shape) == 2: # 双声道
72
+ data = data[:, 0] # 只要第一个声道
73
+
74
+ # 然后再切换sr,因为soundfile可能读取出一个双通道的东西
75
+ if sr != aim_sr:
76
+ data = resampy.resample(data, sr, aim_sr)
77
+ return data
utils/metric_util.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import numpy as np
4
+
5
+
6
+ def calc_ber(watermark_decoded_tensor, watermark_tensor, threshold=0.5):
7
+ watermark_decoded_binary = watermark_decoded_tensor >= threshold
8
+ watermark_binary = watermark_tensor >= threshold
9
+ ber_tensor = 1 - (watermark_decoded_binary == watermark_binary).to(torch.float32).mean()
10
+ return ber_tensor
11
+
12
+
13
+ def to_equal_length(original, signal_watermarked):
14
+ if original.shape != signal_watermarked.shape:
15
+ print("警告!输入内容长度不一致", len(original), len(signal_watermarked))
16
+ min_length = min(len(original), len(signal_watermarked))
17
+ original = original[0:min_length]
18
+ signal_watermarked = signal_watermarked[0:min_length]
19
+ assert original.shape == signal_watermarked.shape
20
+ return original, signal_watermarked
21
+
22
+
23
+ def signal_noise_ratio(original, signal_watermarked):
24
+ # 数值越高越好,最好的结果为无穷大
25
+ original, signal_watermarked = to_equal_length(original, signal_watermarked)
26
+ noise_strength = np.sum((original - signal_watermarked) ** 2)
27
+ if noise_strength == 0: # 说明原始信号并未改变
28
+ return np.inf
29
+ signal_strength = np.sum(original ** 2)
30
+ ratio = signal_strength / noise_strength
31
+
32
+ # np.log10(1) == 0
33
+ # 当噪声比信号强度还高时,信噪比就是负的
34
+ # 如果ratio是0,那么 np.log10(0) 就是负无穷 -inf
35
+ # 这里限定一个最小值,以免出现负无穷情况
36
+ ratio = max(1e-10, ratio)
37
+ return 10 * np.log10(ratio)
38
+
39
+
40
+ def batch_signal_noise_ratio(original, signal_watermarked):
41
+ signal = original.detach().cpu().numpy()
42
+ signal_watermarked = signal_watermarked.detach().cpu().numpy()
43
+ tmp_list = []
44
+ for s, swm in zip(signal, signal_watermarked):
45
+ out = signal_noise_ratio(s, swm)
46
+ tmp_list.append(out)
47
+ return np.mean(tmp_list)
48
+
49
+
50
+ def calc_bce_acc(predictions, ground_truth, threshold=0.5):
51
+ assert predictions.shape == ground_truth.shape
52
+
53
+ # 将预测值转换为类别标签
54
+ predicted_labels = (predictions >= threshold).float()
55
+
56
+ # 计算准确率
57
+ accuracy = ((predicted_labels == ground_truth).float().mean().item())
58
+ return accuracy
59
+
60
+
61
+ def resample_to16k(data, old_sr):
62
+ # 对数据进行重采样
63
+ new_fs = 16000
64
+ new_data = data[::int(old_sr / new_fs)]
65
+ return new_data
66
+
67
+
68
+ import pypesq
69
+
70
+
71
+ def pesq(signal1, signal2, sr):
72
+ signal1, signal2 = to_equal_length(signal1, signal2)
73
+
74
+ # Perceptual Evaluation of Speech Quality
75
+ # [−0.5 to 4.5], PESQ>3.5 时音频质量较好,>4.0基本上就听不到了
76
+ # 函数只支持16k或8k的输入,因此在输入前校验采样率。由于这个指标计算的是可感知性,因此这里改变采样率和水印鲁棒性是无关的
77
+ if sr != 16000:
78
+ signal1 = resample_to16k(signal1, sr)
79
+ signal2 = resample_to16k(signal2, sr)
80
+
81
+ try:
82
+ pesq = pypesq.pesq(signal1, signal2, 16000)
83
+ # 可能会有错误:ValueError: ref is all zeros, processing error!
84
+ except Exception as e:
85
+ pesq = 0
86
+ print("pesq计算错误:", e)
87
+
88
+ return pesq
utils/model_util.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import json
4
+ import sys
5
+ from utils import pickle_util
6
+
7
+ history_array = []
8
+
9
+
10
+ def save_model(epoch, model, optimizer, file_save_path):
11
+ dirpath = os.path.abspath(os.path.join(file_save_path, os.pardir))
12
+ if not os.path.exists(dirpath):
13
+ print("mkdir:", dirpath)
14
+ os.makedirs(dirpath)
15
+
16
+ opti = None
17
+ if optimizer is not None:
18
+ opti = optimizer.state_dict()
19
+
20
+ torch.save(obj={
21
+ 'epoch': epoch,
22
+ 'model': model.state_dict(),
23
+ 'optimizer': opti,
24
+ }, f=file_save_path)
25
+
26
+ history_array.append(file_save_path)
27
+
28
+
29
+ def save_model_v4(epoch, model, optimizer, file_save_path, discriminator):
30
+ dirpath = os.path.abspath(os.path.join(file_save_path, os.pardir))
31
+ if not os.path.exists(dirpath):
32
+ print("mkdir:", dirpath)
33
+ os.makedirs(dirpath)
34
+
35
+ opti = None
36
+ if optimizer is not None:
37
+ opti = optimizer.state_dict()
38
+
39
+ torch.save(obj={
40
+ 'epoch': epoch,
41
+ 'model': model.state_dict(),
42
+ 'optimizer': opti,
43
+ "discriminator": discriminator,
44
+ }, f=file_save_path)
45
+
46
+ history_array.append(file_save_path)
47
+
48
+
49
+ def delete_last_saved_model():
50
+ if len(history_array) == 0:
51
+ return
52
+ last_path = history_array.pop()
53
+ if os.path.exists(last_path):
54
+ os.remove(last_path)
55
+ print("delete model:", last_path)
56
+
57
+ if os.path.exists(last_path + ".json"):
58
+ os.remove(last_path + ".json")
59
+
60
+
61
+ def load_model(resume_path, model, optimizer=None, strict=True):
62
+ checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
63
+ start_epoch = checkpoint['epoch'] + 1
64
+ model.load_state_dict(checkpoint['model'], strict=strict)
65
+ if optimizer is not None:
66
+ optimizer.load_state_dict(checkpoint['optimizer'])
67
+ print("checkpoint loaded!")
68
+ return start_epoch
69
+
70
+
71
+ def save_model_v2(model, args, model_save_name):
72
+ model_save_path = os.path.join(args.model_save_folder, args.project, args.name, model_save_name)
73
+ save_model(0, model, None, model_save_path)
74
+ print("save:", model_save_path)
75
+
76
+
77
+ def save_project_info(args):
78
+ run_info = {
79
+ "cmd_str": ' '.join(sys.argv[1:]),
80
+ "args": vars(args),
81
+ }
82
+
83
+ name = "run_info.json"
84
+ folder = os.path.join(args.model_save_folder, args.project, args.name)
85
+ if not os.path.exists(folder):
86
+ os.makedirs(folder)
87
+
88
+ json_file_path = os.path.join(folder, name)
89
+ with open(json_file_path, "w") as f:
90
+ json.dump(run_info, f)
91
+
92
+ print("save_project_info:", json_file_path)
93
+
94
+
95
+ def get_pkl_json(folder):
96
+ names = [i for i in os.listdir(folder) if ".pkl.json" in i]
97
+ assert len(names) == 1
98
+ json_path = os.path.join(folder, names[0])
99
+ obj = pickle_util.read_json(json_path)
100
+ return obj
101
+
102
+
103
+ # 并行
104
+
105
+ def is_data_parallel_checkpoint(state_dict):
106
+ return any(key.startswith('module.') for key in state_dict.keys())
107
+
108
+
109
+ def map_state_dict(state_dict):
110
+ if is_data_parallel_checkpoint(state_dict):
111
+ # 处理 DataParallel 添加的前缀 'module.'
112
+ from collections import OrderedDict
113
+ new_state_dict = OrderedDict()
114
+ for k, v in state_dict.items():
115
+ name = k[7:] if k.startswith('module.') else k # 移除前缀 'module.'
116
+ new_state_dict[name] = v
117
+ return new_state_dict
118
+ return state_dict
utils/pesq_util.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pypesq
2
+ import numpy as np
3
+
4
+
5
+ def batch_pesq(batch_signal, batch_signal_wmd):
6
+ batch_signal1 = batch_signal.detach().cpu().numpy()
7
+ batch_signal2 = batch_signal_wmd.detach().cpu().numpy()
8
+ pesq_array = []
9
+ for signal1, signal2 in zip(batch_signal1, batch_signal2):
10
+ try:
11
+ pesq = pypesq.pesq(signal1, signal2, 16000)
12
+ #可能会有错误:ValueError: ref is all zeros, processing error!
13
+
14
+ except Exception as e:
15
+ print(e)
16
+
17
+ continue
18
+ if np.isnan(pesq):
19
+ print("pesq is nan!")
20
+ continue
21
+ pesq_array.append(pesq)
22
+
23
+ if len(pesq_array) > 0:
24
+ return np.mean(pesq_array)
25
+ return -1
utils/pickle_util.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import _pickle as pickle # python3
2
+ import time
3
+ import json
4
+
5
+
6
+ def read_pickle(filepath):
7
+ f = open(filepath, 'rb')
8
+ word2mfccs = pickle.load(f)
9
+ f.close()
10
+ return word2mfccs
11
+
12
+
13
+ def save_pickle(save_path, save_data):
14
+ f = open(save_path, 'wb')
15
+ pickle.dump(save_data, f)
16
+ f.close()
17
+
18
+
19
+ def read_json(filepath):
20
+ with open(filepath) as f:
21
+ obj = json.load(f)
22
+ return obj
23
+
24
+
25
+ def save_json(save_path, obj):
26
+ with open(save_path, 'w') as f:
27
+ json.dump(obj, f)
utils/silent_util.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def is_silent(data, silence_threshold=0.01):
5
+ rms = np.sqrt(np.mean(data ** 2))
6
+ return rms < silence_threshold
7
+
8
+
9
+ def has_silent_part(trunck):
10
+ num_part = 3
11
+ part_length = int(len(trunck) / num_part)
12
+ for i in range(num_part):
13
+ start = part_length * i
14
+ end = start + part_length
15
+ mini_trunck = trunck[start:end]
16
+ if is_silent(mini_trunck):
17
+ return True
18
+ return False
utils/wm_add_v2.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import silent_util
2
+ import torch
3
+ import numpy as np
4
+ from utils import bin_util
5
+
6
+ fix_pattern = [1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0,
7
+ 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,
8
+ 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1,
9
+ 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0,
10
+ 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0]
11
+
12
+
13
+ def create_parcel_message(len_start_bit, num_bit, wm_text, verbose=False):
14
+ # 2.起始bit
15
+ # start_bit = np.array([0] * len_start_bit)
16
+ start_bit = fix_pattern[0:len_start_bit]
17
+ error_prob = 2 ** len_start_bit / 10000
18
+ # todo:考虑threshold的时候的错误率呢?
19
+ if verbose:
20
+ print("起始bit长度:%d,错误率:%.1f万" % (len(start_bit), error_prob))
21
+
22
+ # 3.信息内容
23
+ length_msg = num_bit - len(start_bit)
24
+ if wm_text:
25
+ msg_arr = bin_util.hexStr2BinArray(wm_text)
26
+ else:
27
+ msg_arr = np.random.choice([0, 1], size=length_msg)
28
+
29
+ # 4.封装信息
30
+ watermark = np.concatenate([start_bit, msg_arr])
31
+ assert len(watermark) == num_bit
32
+ return start_bit, msg_arr, watermark
33
+
34
+
35
+ import time
36
+
37
+
38
+ def add_watermark(bir_array, data, num_point, shift_range, device, model, silence_check=False):
39
+ t1 = time.time()
40
+ # 1.获得区块大小
41
+ chunk_size = num_point + int(num_point * shift_range)
42
+
43
+ output_chunks = []
44
+ idx_trunck = -1
45
+ for i in range(0, len(data), chunk_size):
46
+ idx_trunck += 1
47
+ current_chunk = data[i:i + chunk_size].copy()
48
+ # 最后一块,长度不足
49
+ if len(current_chunk) < chunk_size:
50
+ output_chunks.append(current_chunk)
51
+ break
52
+
53
+ # 处理区块: [水印区|间隔区]
54
+ current_chunk_cover_area = current_chunk[0:num_point]
55
+ current_chunk_shift_area = current_chunk[num_point:]
56
+ current_chunk_cover_area_wmd = encode_trunck_with_silence_check(silence_check,
57
+ idx_trunck,
58
+ current_chunk_cover_area, bir_array,
59
+ device, model)
60
+ output = np.concatenate([current_chunk_cover_area_wmd, current_chunk_shift_area])
61
+ assert output.shape == current_chunk.shape
62
+ output_chunks.append(output)
63
+
64
+ assert len(output_chunks) > 0
65
+ reconstructed_array = np.concatenate(output_chunks)
66
+ time_cost = time.time() - t1
67
+ return data, reconstructed_array, time_cost
68
+
69
+
70
+ def encode_trunck_with_silence_check(silence_check, trunck_idx, trunck, wm, device, model):
71
+ # 1.判断是否是静音,通过判断子段是否静音来处理
72
+ if silence_check and silent_util.is_silent(trunck):
73
+ print("跳过静音区块:", trunck_idx)
74
+ return trunck
75
+
76
+ # 2.加入水印
77
+ trnck_wmd = encode_trunck(trunck, wm, device, model)
78
+ return trnck_wmd
79
+
80
+
81
+ def encode_trunck(trunck, wm, device, model):
82
+ with torch.no_grad():
83
+ signal = torch.FloatTensor(trunck).to(device)[None]
84
+ message = torch.FloatTensor(np.array(wm)).to(device)[None]
85
+ signal_wmd_tensor = model.encode(signal, message)
86
+ signal_wmd = signal_wmd_tensor.detach().cpu().numpy().squeeze()
87
+ return signal_wmd
utils/wm_decode_v2.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+
3
+ import torch
4
+ import numpy as np
5
+ from utils import bin_util
6
+
7
+
8
+ def decode_trunck(trunck, model, device):
9
+ with torch.no_grad():
10
+ signal = torch.FloatTensor(trunck).to(device).unsqueeze(0)
11
+ message = (model.decode(signal) >= 0.5).int()
12
+ message = message.detach().cpu().numpy().squeeze()
13
+ return message
14
+
15
+
16
+ def is_start_bit_match(start_bit, decoded_start_bit, start_bit_ber_threshold):
17
+ assert decoded_start_bit.shape == start_bit.shape
18
+ ber = 1 - np.mean(start_bit == decoded_start_bit)
19
+ return ber < start_bit_ber_threshold
20
+
21
+
22
+ def extract_watermark(data, start_bit, shift_range, num_point, start_bit_ber_threshold, model, device,
23
+ verbose=False):
24
+ # pdb.set_trace()
25
+ shift_range_points = int(shift_range * num_point)
26
+ i = 0 # 当前的指针位置
27
+ results = []
28
+ while True:
29
+ start = i
30
+ end = start + num_point
31
+ trunck = data[start:end]
32
+ if len(trunck) < num_point:
33
+ break
34
+
35
+ bit_array = decode_trunck(trunck, model, device)
36
+ decoded_start_bit = bit_array[0:len(start_bit)]
37
+ if not is_start_bit_match(start_bit, decoded_start_bit, start_bit_ber_threshold):
38
+ i = i + shift_range_points
39
+ continue
40
+ # 寻找到了起始位置
41
+ if verbose:
42
+ msg_bit = bit_array[len(start_bit):]
43
+ msg_str = bin_util.binArray2HexStr(msg_bit)
44
+ print(i, "解码信息:", msg_str)
45
+ results.append(bit_array)
46
+ i = i + num_point + shift_range_points
47
+
48
+ support_count = len(results)
49
+ if support_count == 0:
50
+ mean_result = None
51
+ first_result = None
52
+ exist_prob = None
53
+ else:
54
+ mean_result = (np.array(results).mean(axis=0) >= 0.5).astype(int)
55
+ exist_prob = (mean_result[0:len(start_bit)] == start_bit).mean()
56
+ first_result = results[0]
57
+
58
+ return support_count, exist_prob, mean_result, first_result
59
+
60
+
61
+ def extract_watermark_v2(data, start_bit, shift_range, num_point,
62
+ start_bit_ber_threshold, model, device,
63
+ merge_type,
64
+ shift_range_p=0.5, ):
65
+ shift_range_points = int(shift_range * num_point * shift_range_p)
66
+ i = 0 # 当前的指针位置
67
+ results = []
68
+ while True:
69
+ start = i
70
+ end = start + num_point
71
+ trunck = data[start:end]
72
+ if len(trunck) < num_point:
73
+ break
74
+
75
+ bit_array = decode_trunck(trunck, model, device)
76
+ decoded_start_bit = bit_array[0:len(start_bit)]
77
+
78
+ ber_start_bit = 1 - np.mean(start_bit == decoded_start_bit)
79
+ if ber_start_bit > start_bit_ber_threshold:
80
+ i = i + shift_range_points
81
+ continue
82
+ # 寻找到了起始位置
83
+ results.append({
84
+ "sim": 1 - ber_start_bit,
85
+ "msg": bit_array,
86
+ })
87
+ # 这里很重要,如果threshold设置的太大,那么就会跳过一些可能的点
88
+ # i = i + num_point + shift_range_points
89
+ i = i + shift_range_points
90
+
91
+ support_count = len(results)
92
+ if support_count == 0:
93
+ mean_result = None
94
+ else:
95
+ # 1.加权得到最终结果
96
+ if merge_type == "weighted":
97
+ raise Exception("")
98
+ elif merge_type == "best":
99
+ # 相似度从大到小排序
100
+ best_val = sorted(results, key=lambda x: x["sim"], reverse=True)[0]
101
+ if np.isclose(1.0, best_val["sim"]):
102
+ # 那么对所有为1.0的进行求平均
103
+ results_1 = [i["msg"] for i in results if np.isclose(i["sim"], 1.0)]
104
+ mean_result = (np.array(results_1).mean(axis=0) >= 0.5).astype(int)
105
+ else:
106
+ mean_result = best_val["msg"]
107
+
108
+ else:
109
+ raise Exception("")
110
+ # assert merge_type == "mean"
111
+ # mean_result = (np.array([i[-1] for i in results]).mean(axis=0) >= 0.5).astype(int)
112
+
113
+ return support_count, mean_result, results