Spaces:
Running
on
T4
Running
on
T4
import pdb | |
import torch | |
import numpy as np | |
from utils import bin_util | |
def decode_trunck(trunck, model, device):
    """Decode one window of samples into a hard-decision bit array.

    Args:
        trunck: Samples of a single window; converted with
            ``torch.FloatTensor``.
        model: Object exposing ``decode(tensor)``; its output is
            thresholded at 0.5.
        device: Torch device the input tensor is moved to.

    Returns:
        NumPy integer array of 0/1 values with singleton dimensions
        squeezed away.
    """
    with torch.no_grad():
        # A leading axis of size 1 is added before calling the model.
        batch = torch.FloatTensor(trunck).to(device).unsqueeze(0)
        scores = model.decode(batch)
        bits = (scores >= 0.5).int()
    return bits.detach().cpu().numpy().squeeze()
def is_start_bit_match(start_bit, decoded_start_bit, start_bit_ber_threshold):
    """Tell whether decoded start bits are close enough to the reference.

    Args:
        start_bit: Reference start-bit array.
        decoded_start_bit: Decoded start-bit array; must have the same
            ``shape`` as ``start_bit``.
        start_bit_ber_threshold: Upper bound (exclusive) on the bit error
            rate for the two arrays to count as a match.

    Returns:
        True when the bit error rate is strictly below the threshold.
    """
    assert decoded_start_bit.shape == start_bit.shape
    agreement = np.mean(start_bit == decoded_start_bit)
    bit_error_rate = 1 - agreement
    return bit_error_rate < start_bit_ber_threshold
def extract_watermark(data, start_bit, shift_range, num_point, start_bit_ber_threshold, model, device,
                      verbose=False):
    """Scan ``data`` for watermark windows and merge the decoded bits.

    A window of ``num_point`` samples is decoded at the current position.
    If its leading bits match ``start_bit`` (see ``is_start_bit_match``) the
    whole decoded array is kept and the cursor jumps past the window;
    otherwise the cursor moves forward by the search step and tries again.
    Scanning stops when fewer than ``num_point`` samples remain.

    Args:
        data: Sliceable sequence of samples supporting ``len``.
        start_bit: Reference start-bit array expected at the head of every
            decoded window.
        shift_range: Factor multiplied by ``num_point`` (then truncated to
            an int) to obtain the search step in samples.
        num_point: Number of samples in one decoding window.
        start_bit_ber_threshold: Threshold forwarded to
            ``is_start_bit_match``.
        model: Decoder passed through to ``decode_trunck``.
        device: Torch device passed through to ``decode_trunck``.
        verbose: When true, print the position and the hex form of the
            message bits for every accepted window.

    Returns:
        Tuple ``(support_count, exist_prob, mean_result, first_result)``:
        the number of accepted windows, the fraction of start-bit positions
        of the merged result that equal ``start_bit``, the per-bit majority
        vote over all accepted windows, and the first accepted window.
        The last three are ``None`` when nothing was accepted.
    """
    shift_range_points = int(shift_range * num_point)
    # A zero-sample step would leave the cursor in place after a failed
    # match and spin forever, so always advance by at least one sample
    # while searching. The post-match jump below is left unchanged.
    search_step = max(1, shift_range_points)
    start_len = len(start_bit)

    i = 0  # current cursor position
    results = []
    while True:
        trunck = data[i:i + num_point]
        if len(trunck) < num_point:
            break

        bit_array = decode_trunck(trunck, model, device)
        decoded_start_bit = bit_array[0:start_len]
        if not is_start_bit_match(start_bit, decoded_start_bit, start_bit_ber_threshold):
            i = i + search_step
            continue

        # A start position has been found.
        if verbose:
            msg_bit = bit_array[start_len:]
            msg_str = bin_util.binArray2HexStr(msg_bit)
            print(i, "解码信息:", msg_str)
        results.append(bit_array)
        i = i + num_point + shift_range_points

    support_count = len(results)
    if not results:
        return support_count, None, None, None

    # Per-bit majority vote across every accepted window.
    mean_result = (np.array(results).mean(axis=0) >= 0.5).astype(int)
    exist_prob = (mean_result[0:start_len] == start_bit).mean()
    first_result = results[0]
    return support_count, exist_prob, mean_result, first_result
def extract_watermark_v2(data, start_bit, shift_range, num_point,
                         start_bit_ber_threshold, model, device,
                         merge_type,
                         shift_range_p=0.5, ):
    """Scan ``data`` with a sliding window and pick the best decoded message.

    Every window of ``num_point`` samples is decoded; windows whose start
    bits have a bit error rate not exceeding ``start_bit_ber_threshold`` are
    collected together with their similarity (``1 - BER``). The cursor
    always advances by the same step, whether or not the window was
    accepted, so overlapping candidates are all examined.

    Args:
        data: Sliceable sequence of samples supporting ``len``.
        start_bit: Reference start-bit array expected at the head of every
            decoded window.
        shift_range: Factor used, together with ``num_point`` and
            ``shift_range_p``, to compute the step in samples.
        num_point: Number of samples in one decoding window.
        start_bit_ber_threshold: Largest start-bit error rate that is still
            accepted (inclusive).
        model: Decoder passed through to ``decode_trunck``.
        device: Torch device passed through to ``decode_trunck``.
        merge_type: Merge strategy; only ``"best"`` is implemented.
        shift_range_p: Extra scale factor applied to the step.

    Returns:
        Tuple ``(support_count, mean_result, results)``: the number of
        accepted windows, the merged bit array (``None`` when nothing was
        accepted), and the list of ``{"sim": ..., "msg": ...}`` records.

    Raises:
        NotImplementedError: ``merge_type`` is ``"weighted"`` and at least
            one window was accepted.
        ValueError: ``merge_type`` is neither ``"weighted"`` nor ``"best"``
            and at least one window was accepted.
    """
    # Both branches of the loop advance by this step, so a value of zero
    # would never terminate; clamp it to at least one sample.
    shift_range_points = max(1, int(shift_range * num_point * shift_range_p))
    start_len = len(start_bit)

    i = 0  # current cursor position
    results = []
    while True:
        trunck = data[i:i + num_point]
        if len(trunck) < num_point:
            break

        bit_array = decode_trunck(trunck, model, device)
        decoded_start_bit = bit_array[0:start_len]
        ber_start_bit = 1 - np.mean(start_bit == decoded_start_bit)
        rejected = ber_start_bit > start_bit_ber_threshold
        if not rejected:
            # A start position has been found.
            results.append({
                "sim": 1 - ber_start_bit,
                "msg": bit_array,
            })
        # Important: step by the shift only, never by a whole window. With
        # a large threshold, jumping past the window after a hit would skip
        # other candidate positions.
        i = i + shift_range_points

    support_count = len(results)
    if not results:
        return support_count, None, results

    if merge_type == "weighted":
        raise NotImplementedError('merge_type "weighted" is not implemented')
    if merge_type != "best":
        raise ValueError("unsupported merge_type: {!r}".format(merge_type))

    # Highest similarity wins; on ties the earliest window is used.
    best_val = max(results, key=lambda rec: rec["sim"])
    if np.isclose(1.0, best_val["sim"]):
        # Average every window whose start bits matched perfectly.
        perfect = [rec["msg"] for rec in results if np.isclose(rec["sim"], 1.0)]
        mean_result = (np.array(perfect).mean(axis=0) >= 0.5).astype(int)
    else:
        mean_result = best_val["msg"]
    return support_count, mean_result, results