File size: 3,461 Bytes
32ca76b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from utils import silent_util
import torch
import numpy as np
from utils import bin_util

# Fixed 100-element 0/1 sequence. create_parcel_message takes its first
# `len_start_bit` entries as the watermark's start (header) bits, so this length
# is the upper bound on the header size. Do not edit: presumably the decoder
# relies on the exact same bits to locate the header (confirm against decoder).
fix_pattern = [1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0,
               0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,
               1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1,
               1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0,
               0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0]


def create_parcel_message(len_start_bit, num_bit, wm_text, verbose=False):
    """Build a watermark bit sequence: fixed start bits followed by payload bits.

    Args:
        len_start_bit: Number of leading bits taken from ``fix_pattern`` as the
            header. Must lie in ``[0, len(fix_pattern)]``.
        num_bit: Total length of the returned watermark.
        wm_text: Payload as a hex string, converted with
            ``bin_util.hexStr2BinArray``. If falsy (``None`` / ``""``), a random
            0/1 payload of the remaining length is generated instead.
        verbose: If True, print the header length and a rough false-match figure.

    Returns:
        Tuple ``(start_bit, msg_arr, watermark)``. ``start_bit`` is a list slice
        of ``fix_pattern``, ``msg_arr`` holds the payload bits, and ``watermark``
        is their numpy concatenation, of length ``num_bit``.

    Raises:
        ValueError: If ``len_start_bit`` is out of range, if ``num_bit`` is
            smaller than the header, or if the decoded payload does not exactly
            fill ``num_bit - len_start_bit`` bits.
    """
    # 2. Start bits: a prefix of the fixed pattern. Reject sizes the slice would
    #    silently truncate (too large) or misinterpret (negative).
    if not 0 <= len_start_bit <= len(fix_pattern):
        raise ValueError("len_start_bit must be in [0, %d], got %r"
                         % (len(fix_pattern), len_start_bit))
    start_bit = fix_pattern[0:len_start_bit]

    if verbose:
        # A uniformly random bit string matches the header with probability
        # 2**-len_start_bit, i.e. about once in 2**len_start_bit tries; the value
        # below is that count expressed in units of 10,000 (the "万" in the message).
        # TODO: what is the false-match rate once a matching threshold is applied?
        inverse_match_rate_10k = 2 ** len_start_bit / 10000
        print("起始bit长度:%d,错误率:%.1f万" % (len(start_bit), inverse_match_rate_10k))

    # 3. Payload bits fill the rest of the watermark.
    length_msg = num_bit - len(start_bit)
    if length_msg < 0:
        raise ValueError("num_bit (%r) is smaller than len_start_bit (%r)"
                         % (num_bit, len_start_bit))
    if wm_text:
        msg_arr = bin_util.hexStr2BinArray(wm_text)
        # Explicit check instead of an assert, so it still fires under `python -O`.
        if len(msg_arr) != length_msg:
            raise ValueError("payload has %d bits but %d are required (num_bit=%r, len_start_bit=%r)"
                             % (len(msg_arr), length_msg, num_bit, len_start_bit))
    else:
        msg_arr = np.random.choice([0, 1], size=length_msg)

    # 4. Assemble the final watermark: [start bits | payload].
    watermark = np.concatenate([start_bit, msg_arr])
    return start_bit, msg_arr, watermark


import time


def add_watermark(bir_array, data, num_point, shift_range, device, model, silence_check=False):
    """Embed ``bir_array`` into ``data`` chunk by chunk.

    ``data`` is split into consecutive chunks of
    ``num_point + int(num_point * shift_range)`` samples.
    - In each full chunk, the first ``num_point`` samples (the cover area) go
      through ``encode_trunck_with_silence_check``.
    - The remaining samples (the shift area) are copied unchanged.
    - A trailing chunk shorter than a full chunk is copied unchanged.

    Args:
        bir_array: Watermark bits forwarded to the encoder.
        data: Input signal. Presumably a 1-D numpy array of samples; it must
            support ``len``, slicing and ``.copy()`` (verify against callers).
        num_point: Length of the cover area that receives the watermark.
        shift_range: Fraction of ``num_point`` appended to each chunk as an
            un-watermarked gap.
        device: Device forwarded to the encoder.
        model: Object exposing ``encode(signal, message)``, forwarded to the encoder.
        silence_check: If True, chunks judged silent are left unwatermarked.

    Returns:
        Tuple ``(data, reconstructed_array, time_cost)``: the input object
        itself, the watermarked signal, and the elapsed time in seconds. For
        empty ``data`` the watermarked signal is an empty copy.

    Raises:
        ValueError: If the computed chunk size is not positive.
        RuntimeError: If the encoder changes the length of a cover area.
    """
    t_start = time.perf_counter()  # monotonic clock, suited to measuring durations

    # 1. Chunk = [cover area (num_point) | shift area (int(num_point * shift_range))]
    chunk_size = num_point + int(num_point * shift_range)
    if chunk_size <= 0:
        raise ValueError("chunk size must be positive, got %r (num_point=%r, shift_range=%r)"
                         % (chunk_size, num_point, shift_range))

    output_chunks = []
    for idx_trunck, start in enumerate(range(0, len(data), chunk_size)):
        current_chunk = data[start:start + chunk_size].copy()
        # Last chunk is too short: keep it as-is.
        if len(current_chunk) < chunk_size:
            output_chunks.append(current_chunk)
            break

        cover_area = current_chunk[:num_point]
        shift_area = current_chunk[num_point:]
        cover_area_wmd = encode_trunck_with_silence_check(silence_check,
                                                          idx_trunck,
                                                          cover_area, bir_array,
                                                          device, model)
        output = np.concatenate([cover_area_wmd, shift_area])
        # Explicit check (not an assert) so a length-changing encoder can't
        # silently misalign the output under `python -O`.
        if output.shape != current_chunk.shape:
            raise RuntimeError("encoded chunk %d has shape %r, expected %r"
                               % (idx_trunck, output.shape, current_chunk.shape))
        output_chunks.append(output)

    if output_chunks:
        reconstructed_array = np.concatenate(output_chunks)
    else:
        # Empty input: nothing to embed; return an empty copy instead of crashing.
        reconstructed_array = np.asarray(data).copy()
    time_cost = time.perf_counter() - t_start
    return data, reconstructed_array, time_cost


def encode_trunck_with_silence_check(silence_check, trunck_idx, trunck, wm, device, model):
    """Watermark one chunk unless silence checking is enabled and the chunk is silent.

    Args:
        silence_check: When truthy, ``silent_util.is_silent`` is consulted first.
        trunck_idx: Index of the chunk; only used in the skip message.
        trunck: The samples to watermark.
        wm: Watermark bits passed to ``encode_trunck``.
        device: Device passed to ``encode_trunck``.
        model: Model passed to ``encode_trunck``.

    Returns:
        The original ``trunck`` object when it is skipped as silent (after
        printing its index); otherwise the result of ``encode_trunck``.
    """
    # The silence detector only runs when the caller asked for it (short-circuit).
    skip_chunk = silence_check and silent_util.is_silent(trunck)
    if not skip_chunk:
        return encode_trunck(trunck, wm, device, model)
    print("跳过静音区块:", trunck_idx)
    return trunck


def encode_trunck(trunck, wm, device, model):
    """Run ``model.encode`` on one chunk and its watermark bits, without autograd.

    Both inputs are converted to float32 tensors on ``device`` and given a
    leading batch axis of size 1 before calling ``model.encode``.

    Returns:
        The encoder output as a CPU numpy array, with every size-1 dimension
        removed by ``squeeze``.
    """
    with torch.no_grad():
        batch_signal = torch.FloatTensor(trunck).to(device).unsqueeze(0)
        batch_message = torch.FloatTensor(np.array(wm)).to(device).unsqueeze(0)
        encoded = model.encode(batch_signal, batch_message)
    host_array = encoded.detach().cpu().numpy()
    return host_array.squeeze()