WavMark / utils /wm_add_v2.py
my
Add application file
32ca76b
raw
history blame
No virus
3.46 kB
from utils import silent_util
import torch
import numpy as np
from utils import bin_util
# Fixed pseudo-random 0/1 table (100 entries). create_parcel_message() uses
# the leading `len_start_bit` entries as the start-bit marker that is placed
# in front of the payload bits.
fix_pattern = [
    1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0,
    0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,
    1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1,
    1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0,
    0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0,
]
def create_parcel_message(len_start_bit, num_bit, wm_text, verbose=False):
    """Assemble the bit sequence to embed: start bits followed by payload bits.

    Args:
        len_start_bit: Number of leading entries of the module-level
            ``fix_pattern`` to use as the start-bit marker.
        num_bit: Required total length of the assembled watermark.
        wm_text: Payload passed to ``bin_util.hexStr2BinArray`` (presumably a
            hex string converted to a 0/1 array; confirm against ``bin_util``).
            When falsy, a random 0/1 payload is generated instead.
        verbose: When true, print the start-bit length and the derived
            error figure.

    Returns:
        A tuple ``(start_bit, msg_arr, watermark)`` where ``start_bit`` is a
        list slice of ``fix_pattern``, ``msg_arr`` is the payload, and
        ``watermark`` is their concatenation as a NumPy array.

    Raises:
        AssertionError: If the assembled watermark is not ``num_bit`` long
            (possible when ``wm_text`` yields a payload of another length).
    """
    # Start bits: the leading slice of the fixed pattern.
    start_bit = fix_pattern[:len_start_bit]
    n_start = len(start_bit)

    # 2**len_start_bit expressed in units of 10,000; used only for the
    # verbose message below.
    # TODO: what about the error rate once a matching threshold is applied?
    error_prob = 2 ** len_start_bit / 10000
    if verbose:
        print("起始bit长度:%d,错误率:%.1f万" % (n_start, error_prob))

    # Payload: random bits filling the remaining length when no text is
    # given, otherwise whatever the hex-string converter returns.
    length_msg = num_bit - n_start
    if not wm_text:
        msg_arr = np.random.choice([0, 1], size=length_msg)
    else:
        msg_arr = bin_util.hexStr2BinArray(wm_text)

    # Pack start bits and payload into one sequence.
    watermark = np.concatenate([start_bit, msg_arr])
    assert len(watermark) == num_bit
    return start_bit, msg_arr, watermark
import time
def add_watermark(bir_array, data, num_point, shift_range, device, model, silence_check=False):
    """Embed a watermark into ``data`` block by block.

    ``data`` is walked in consecutive blocks of
    ``num_point + int(num_point * shift_range)`` samples. In each full block
    the first ``num_point`` samples (the watermark region) are passed to
    ``encode_trunck_with_silence_check`` and the remaining samples (the gap
    region) are kept unchanged. A trailing block shorter than the block size
    is copied through as is.

    Args:
        bir_array: Watermark bits handed to the encoder for every block.
        data: 1-D sample sequence supporting slicing and ``.copy()``
            (presumably a NumPy array; confirm against callers).
        num_point: Number of samples in the watermark region of each block.
        shift_range: Gap-region size as a fraction of ``num_point``.
        device: Device forwarded to the encoder.
        model: Model forwarded to the encoder.
        silence_check: Forwarded to the encoder helper to enable its
            silence test.

    Returns:
        A tuple ``(data, reconstructed_array, time_cost)``: the input object
        itself, the concatenated output blocks, and the elapsed wall-clock
        time in seconds.

    Raises:
        AssertionError: If no block was produced (for example, empty
            ``data``) or an encoded block changes shape.
    """
    started_at = time.time()

    # Block size = watermark region + gap region.
    chunk_size = num_point + int(num_point * shift_range)

    pieces = []
    for block_idx, offset in enumerate(range(0, len(data), chunk_size)):
        block = data[offset:offset + chunk_size].copy()

        if len(block) < chunk_size:
            # Final block is too short to process: keep it untouched.
            pieces.append(block)
            break

        # Block layout: [watermark region | gap region]
        cover_area, shift_area = block[:num_point], block[num_point:]
        cover_area_wmd = encode_trunck_with_silence_check(
            silence_check, block_idx, cover_area, bir_array, device, model
        )
        merged = np.concatenate([cover_area_wmd, shift_area])
        assert merged.shape == block.shape
        pieces.append(merged)

    assert len(pieces) > 0
    reconstructed_array = np.concatenate(pieces)
    time_cost = time.time() - started_at
    return data, reconstructed_array, time_cost
def encode_trunck_with_silence_check(silence_check, trunck_idx, trunck, wm, device, model):
    """Watermark one block unless it is detected as silent.

    Args:
        silence_check: When true, ``silent_util.is_silent`` is consulted
            first; when false, the block is always encoded.
        trunck_idx: Index of the block, used only in the skip message.
        trunck: Samples of the block to watermark.
        wm: Watermark bits forwarded to ``encode_trunck``.
        device: Device forwarded to ``encode_trunck``.
        model: Model forwarded to ``encode_trunck``.

    Returns:
        The input ``trunck`` itself when it is skipped as silent, otherwise
        the result of ``encode_trunck``.
    """
    # The silence test runs only when the check is enabled.
    is_skipped = silence_check and silent_util.is_silent(trunck)
    if not is_skipped:
        # Embed the watermark.
        return encode_trunck(trunck, wm, device, model)

    print("跳过静音区块:", trunck_idx)
    return trunck
def encode_trunck(trunck, wm, device, model):
    """Run ``model.encode`` on one block and return the result as NumPy.

    Args:
        trunck: Samples of the block; converted to a float tensor.
        wm: Watermark bits; converted to a float tensor via ``np.array``.
        device: Device both tensors are moved to before encoding.
        model: Object exposing ``encode(signal, message)`` that returns a
            tensor.

    Returns:
        The encoder output moved to the CPU, converted to a NumPy array and
        squeezed, so every size-1 axis is dropped.
    """
    with torch.no_grad():
        # A leading axis of size 1 is added to both inputs before encoding.
        signal_batch = torch.FloatTensor(trunck).to(device).unsqueeze(0)
        message_batch = torch.FloatTensor(np.array(wm)).to(device).unsqueeze(0)
        encoded = model.encode(signal_batch, message_batch)
        return encoded.detach().cpu().numpy().squeeze()