File size: 4,100 Bytes
32ca76b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import pdb

import torch
import numpy as np
from utils import bin_util


def decode_trunck(trunck, model, device):
    """Decode one chunk of samples into a hard 0/1 bit array.

    The chunk is turned into a float tensor, moved to ``device``, given a
    leading dimension of size 1 and handed to ``model.decode``. Every output
    value greater than or equal to 0.5 becomes 1, everything else becomes 0.

    Args:
        trunck: Samples accepted by ``torch.FloatTensor``.
        model: Object exposing a ``decode`` method that takes the tensor.
        device: Device the input tensor is moved to before decoding.

    Returns:
        NumPy integer array of decoded bits with size-1 dimensions squeezed
        out.
    """
    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        batch = torch.unsqueeze(torch.FloatTensor(trunck).to(device), 0)
        scores = model.decode(batch)
        hard_bits = (scores >= 0.5).int()
        return hard_bits.detach().cpu().numpy().squeeze()


def is_start_bit_match(start_bit, decoded_start_bit, start_bit_ber_threshold):
    """Tell whether the decoded start bits are close enough to the expected ones.

    Args:
        start_bit: Expected start-bit array (must expose ``.shape``).
        decoded_start_bit: Decoded start-bit array of the same shape.
        start_bit_ber_threshold: Exclusive upper bound on the bit error rate.

    Returns:
        True when the bit error rate (one minus the fraction of equal
        positions) is strictly below ``start_bit_ber_threshold``.

    Raises:
        AssertionError: If the two arrays differ in shape.
    """
    assert start_bit.shape == decoded_start_bit.shape
    matching_fraction = np.mean(start_bit == decoded_start_bit)
    bit_error_rate = 1 - matching_fraction
    return bit_error_rate < start_bit_ber_threshold


def extract_watermark(data, start_bit, shift_range, num_point, start_bit_ber_threshold, model, device,
                      verbose=False):
    """Scan ``data`` for watermarked chunks and merge the decoded bit arrays.

    A window of ``num_point`` samples is decoded at the current position. If
    its leading ``len(start_bit)`` bits match ``start_bit`` (according to
    ``is_start_bit_match``) the full decoded array is kept and the cursor
    jumps past the window plus one search step; otherwise the cursor moves
    forward by one search step. Scanning stops when fewer than ``num_point``
    samples remain.

    Args:
        data: Sliceable sequence of samples.
        start_bit: Expected start-bit array.
        shift_range: Search step expressed as a fraction of ``num_point``.
        num_point: Number of samples per decoded window.
        start_bit_ber_threshold: Threshold passed to ``is_start_bit_match``.
        model: Decoder passed to ``decode_trunck``.
        device: Device passed to ``decode_trunck``.
        verbose: When true, print the position and the hex string of the
            message bits of every accepted window.

    Returns:
        Tuple ``(support_count, exist_prob, mean_result, first_result)``:
        the number of accepted windows, the fraction of start bits of the
        merged result equal to ``start_bit``, the per-bit majority (mean
        >= 0.5) over all accepted windows, and the first accepted bit array.
        The last three are ``None`` when no window was accepted.
    """
    shift_range_points = int(shift_range * num_point)
    # A step of zero (or less) would leave the cursor stuck on the same
    # window forever, so both advances move forward by at least one sample.
    miss_step = max(1, shift_range_points)
    hit_step = max(1, num_point + shift_range_points)
    start_bit_len = len(start_bit)

    cursor = 0  # current read position in ``data``
    results = []
    while True:
        trunck = data[cursor:cursor + num_point]
        if len(trunck) < num_point:
            break

        bit_array = decode_trunck(trunck, model, device)
        decoded_start_bit = bit_array[0:start_bit_len]
        if not is_start_bit_match(start_bit, decoded_start_bit, start_bit_ber_threshold):
            cursor += miss_step
            continue

        # A start position has been found.
        if verbose:
            msg_bit = bit_array[start_bit_len:]
            msg_str = bin_util.binArray2HexStr(msg_bit)
            print(cursor, "解码信息:", msg_str)
        results.append(bit_array)
        cursor += hit_step

    support_count = len(results)
    if support_count == 0:
        return support_count, None, None, None

    # Per-bit majority vote across every accepted window.
    mean_result = (np.array(results).mean(axis=0) >= 0.5).astype(int)
    exist_prob = (mean_result[0:start_bit_len] == start_bit).mean()
    first_result = results[0]
    return support_count, exist_prob, mean_result, first_result


def extract_watermark_v2(data, start_bit, shift_range, num_point,
                         start_bit_ber_threshold, model, device,
                         merge_type,
                         shift_range_p=0.5, ):
    """Scan ``data`` with a fine step and merge the best-matching decodes.

    A window of ``num_point`` samples is decoded at the current position and
    the bit error rate of its leading ``len(start_bit)`` bits against
    ``start_bit`` is computed. Windows whose error rate exceeds
    ``start_bit_ber_threshold`` are rejected; the others are recorded with
    their similarity (one minus the error rate). In both cases the cursor
    advances by the search step only.

    Args:
        data: Sliceable sequence of samples.
        start_bit: Expected start-bit array.
        shift_range: Search step expressed as a fraction of ``num_point``.
        num_point: Number of samples per decoded window.
        start_bit_ber_threshold: Largest accepted start-bit error rate
            (inclusive).
        model: Decoder passed to ``decode_trunck``.
        device: Device passed to ``decode_trunck``.
        merge_type: Merge strategy; only ``"best"`` is implemented.
        shift_range_p: Extra factor applied to the search step.

    Returns:
        Tuple ``(support_count, mean_result, results)``: the number of
        accepted windows, the merged bit array (``None`` when nothing was
        accepted) and the list of ``{"sim": ..., "msg": ...}`` records.

    Raises:
        NotImplementedError: If ``merge_type`` is ``"weighted"`` and at least
            one window was accepted.
        ValueError: If ``merge_type`` is any other unsupported value and at
            least one window was accepted.
    """
    # A step of zero (or less) would never move the cursor, looping forever
    # and, on a hit, growing ``results`` without bound; force progress.
    shift_range_points = max(1, int(shift_range * num_point * shift_range_p))
    start_bit_len = len(start_bit)

    cursor = 0  # current read position in ``data``
    results = []
    while True:
        trunck = data[cursor:cursor + num_point]
        if len(trunck) < num_point:
            break

        bit_array = decode_trunck(trunck, model, device)
        decoded_start_bit = bit_array[0:start_bit_len]
        ber_start_bit = 1 - np.mean(start_bit == decoded_start_bit)

        if ber_start_bit <= start_bit_ber_threshold:
            # A start position has been found.
            results.append({
                "sim": 1 - ber_start_bit,
                "msg": bit_array,
            })
        # Important: advance by the search step only, even after a hit.
        # Jumping past the whole window would skip candidate positions when
        # the threshold is loose.
        cursor += shift_range_points

    support_count = len(results)
    if support_count == 0:
        return support_count, None, results

    if merge_type == "weighted":
        raise NotImplementedError("merge_type 'weighted' is not implemented")
    if merge_type != "best":
        raise ValueError("unsupported merge_type: {!r}".format(merge_type))

    # The earliest record with the highest similarity wins.
    best_val = max(results, key=lambda record: record["sim"])
    if np.isclose(1.0, best_val["sim"]):
        # Several perfect matches may exist: majority-vote over all of them.
        perfect_msgs = [record["msg"] for record in results
                        if np.isclose(record["sim"], 1.0)]
        mean_result = (np.array(perfect_msgs).mean(axis=0) >= 0.5).astype(int)
    else:
        mean_result = best_val["msg"]

    return support_count, mean_result, results