WavMark / utils /wm_decode_v2.py
my
Add application file
32ca76b
raw
history blame
No virus
4.1 kB
import pdb
import torch
import numpy as np
from utils import bin_util
def decode_trunck(trunck, model, device):
    """Decode one chunk with ``model`` and return the thresholded bits.

    The chunk is converted to a float tensor, moved to ``device`` and given a
    leading dimension of size 1 before being handed to ``model.decode``.
    Decoder outputs ``>= 0.5`` become 1, everything else becomes 0.

    Args:
        trunck: Values accepted by ``torch.FloatTensor``.
        model: Object exposing a ``decode(tensor)`` method.
        device: Target passed to ``Tensor.to``.

    Returns:
        NumPy integer array holding the thresholded decoder output with all
        size-1 dimensions squeezed away.
    """
    batch = torch.FloatTensor(trunck).to(device).unsqueeze(0)
    with torch.no_grad():
        scores = model.decode(batch)
        bits = (scores >= 0.5).int()
    return bits.detach().cpu().numpy().squeeze()
def is_start_bit_match(start_bit, decoded_start_bit, start_bit_ber_threshold):
    """Check whether the decoded start bits are close enough to the expected ones.

    The bit error rate is computed as ``1 - mean(start_bit == decoded_start_bit)``
    and the match succeeds only when it is strictly below the threshold.

    Args:
        start_bit: Expected start-bit array.
        decoded_start_bit: Decoded start-bit array; must have the same
            ``shape`` as ``start_bit`` (enforced with an ``assert``).
        start_bit_ber_threshold: Exclusive upper bound on the bit error rate.

    Returns:
        Truthy value when the bit error rate is below the threshold.
    """
    assert decoded_start_bit.shape == start_bit.shape
    agreement = np.mean(start_bit == decoded_start_bit)
    return (1 - agreement) < start_bit_ber_threshold
def extract_watermark(data, start_bit, shift_range, num_point, start_bit_ber_threshold, model, device,
                      verbose=False):
    """Scan ``data`` window by window and collect decodes whose start bits match.

    A window of ``num_point`` items is decoded with ``decode_trunck``. If its
    leading ``len(start_bit)`` bits match ``start_bit`` (as judged by
    ``is_start_bit_match``), the whole decoded bit array is kept and scanning
    resumes after that window plus one search step. Otherwise the window is
    moved forward by one search step. Scanning stops at the first window that
    is shorter than ``num_point``.

    Args:
        data: Sequence supporting slicing and ``len()``.
        start_bit: Expected start-bit array (presumably a NumPy array; it is
            compared element-wise with the decoded bits).
        shift_range: Search step expressed as a fraction of ``num_point``.
        num_point: Window length; must be positive.
        start_bit_ber_threshold: Threshold forwarded to ``is_start_bit_match``.
        model: Decoder forwarded to ``decode_trunck``.
        device: Device forwarded to ``decode_trunck``.
        verbose: When true, print the position and the hex string of the
            message bits for every matching window.

    Returns:
        Tuple ``(support_count, exist_prob, mean_result, first_result)``:
        the number of matching windows; the fraction of start-bit positions
        in ``mean_result`` equal to ``start_bit``; the element-wise majority
        vote (mean ``>= 0.5``) over all matching bit arrays; and the first
        matching bit array. The last three are ``None`` when nothing matched.

    Raises:
        ValueError: If ``num_point`` is not positive.
    """
    if num_point <= 0:
        # A non-positive window can never trigger the "window too short" exit.
        raise ValueError("num_point must be positive, got %r" % (num_point,))

    shift_range_points = int(shift_range * num_point)
    # A zero (or negative) search step would re-decode the same window forever
    # whenever the start bits do not match, so always advance by >= 1 on a miss.
    miss_step = max(1, shift_range_points)
    n_start = len(start_bit)

    results = []
    i = 0  # current read position in ``data``
    while True:
        trunck = data[i:i + num_point]
        if len(trunck) < num_point:
            break

        bit_array = decode_trunck(trunck, model, device)
        if not is_start_bit_match(start_bit, bit_array[:n_start], start_bit_ber_threshold):
            i += miss_step
            continue

        # Start position found.
        if verbose:
            msg_str = bin_util.binArray2HexStr(bit_array[n_start:])
            print(i, "解码信息:", msg_str)
        results.append(bit_array)

        # Jump past the window that was just decoded, plus one search step.
        i += num_point + shift_range_points

    support_count = len(results)
    if support_count == 0:
        return support_count, None, None, None

    mean_result = (np.array(results).mean(axis=0) >= 0.5).astype(int)
    exist_prob = (mean_result[:n_start] == start_bit).mean()
    return support_count, exist_prob, mean_result, results[0]
def extract_watermark_v2(data, start_bit, shift_range, num_point,
                         start_bit_ber_threshold, model, device,
                         merge_type,
                         shift_range_p=0.5):
    """Scan ``data`` with a sliding window and merge the matching decodes.

    Every window of ``num_point`` items is decoded with ``decode_trunck``. A
    window is kept when the bit error rate of its leading ``len(start_bit)``
    bits against ``start_bit`` does not exceed ``start_bit_ber_threshold``
    (the bound is inclusive). The window always advances by the same search
    step, whether or not it matched. Scanning stops at the first window that
    is shorter than ``num_point``.

    Args:
        data: Sequence supporting slicing and ``len()``.
        start_bit: Expected start-bit array (presumably a NumPy array; it is
            compared element-wise with the decoded bits).
        shift_range: Search step expressed as a fraction of ``num_point``.
        num_point: Window length; must be positive.
        start_bit_ber_threshold: Largest accepted start-bit error rate.
        model: Decoder forwarded to ``decode_trunck``.
        device: Device forwarded to ``decode_trunck``.
        merge_type: How matching decodes are merged. Only ``"best"`` is
            implemented.
        shift_range_p: Extra scale factor applied to the search step.

    Returns:
        Tuple ``(support_count, mean_result, results)``. ``results`` is a list
        of ``{"sim": ..., "msg": ...}`` dicts, one per matching window, where
        ``sim`` is ``1 - bit_error_rate`` and ``msg`` the decoded bit array.
        ``mean_result`` is ``None`` when nothing matched. With ``"best"``: if
        the highest ``sim`` is (close to) 1.0 it is the element-wise majority
        vote over all such perfect matches, otherwise it is the ``msg`` of
        the highest-``sim`` match.

    Raises:
        ValueError: If ``num_point`` is not positive, or if at least one
            window matched and ``merge_type`` is not a known value.
        NotImplementedError: If at least one window matched and
            ``merge_type`` is ``"weighted"``.
    """
    if num_point <= 0:
        # A non-positive window can never trigger the "window too short" exit.
        raise ValueError("num_point must be positive, got %r" % (num_point,))

    # The step is applied after every window; a value of 0 would decode the
    # same window forever, so it is clamped to at least 1.
    step = max(1, int(shift_range * num_point * shift_range_p))
    n_start = len(start_bit)

    results = []
    i = 0  # current read position in ``data``
    while True:
        trunck = data[i:i + num_point]
        if len(trunck) < num_point:
            break

        bit_array = decode_trunck(trunck, model, device)
        ber_start_bit = 1 - np.mean(start_bit == bit_array[:n_start])
        if ber_start_bit > start_bit_ber_threshold:
            i += step
            continue

        # Start position found.
        results.append({
            "sim": 1 - ber_start_bit,
            "msg": bit_array,
        })
        # Important: advance by the search step only (not past the whole
        # window) even after a hit; with a large threshold, jumping further
        # would skip other candidate positions.
        i += step

    support_count = len(results)
    if support_count == 0:
        return support_count, None, results

    # The merge strategy is validated only once there is something to merge,
    # so a scan without matches never raises for it.
    if merge_type == "weighted":
        raise NotImplementedError('merge_type "weighted" is not implemented')
    if merge_type != "best":
        raise ValueError("unsupported merge_type: %r" % (merge_type,))

    # Highest similarity wins; on ties the earliest match is taken.
    best_val = max(results, key=lambda r: r["sim"])
    if np.isclose(1.0, best_val["sim"]):
        # Average over every match whose similarity is 1.0.
        perfect = [r["msg"] for r in results if np.isclose(r["sim"], 1.0)]
        mean_result = (np.array(perfect).mean(axis=0) >= 0.5).astype(int)
    else:
        mean_result = best_val["msg"]

    return support_count, mean_result, results