SWilliams20
commited on
Commit
•
994be6c
1
Parent(s):
261683e
Update watermark_function.py
Browse files- watermark_function.py +65 -37
watermark_function.py
CHANGED
@@ -1,50 +1,78 @@
|
|
1 |
-
#
|
2 |
|
3 |
-
# Import necessary libraries
|
4 |
import numpy as np
|
5 |
-
import
|
|
|
|
|
6 |
|
7 |
-
# Function to
|
8 |
-
def
|
9 |
-
|
|
|
10 |
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
14 |
|
15 |
-
|
16 |
-
|
|
|
|
|
17 |
|
18 |
-
#
|
19 |
-
|
20 |
|
21 |
-
#
|
22 |
-
|
|
|
|
|
|
|
23 |
|
24 |
-
#
|
25 |
-
|
26 |
-
# Note: This logic depends on the method used for watermark embedding
|
27 |
-
# Here, assuming watermark is embedded as a specific value in weights
|
28 |
-
watermark_value = 1.0 # Example watermark value
|
29 |
|
30 |
-
#
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
34 |
|
35 |
-
|
|
|
|
|
36 |
|
37 |
-
|
38 |
-
if __name__ == "__main__":
|
39 |
-
# Load your trained model and test data
|
40 |
-
# Example: Load model and test data
|
41 |
-
model = tf.keras.models.load_model('path_to_your_model')
|
42 |
-
test_data = np.random.random((100, 10)) # Example test data
|
43 |
|
44 |
-
|
45 |
-
|
|
|
|
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# watermarking_functions.py
|
2 |
|
|
|
3 |
import numpy as np
|
4 |
+
import hashlib
|
5 |
+
import random
|
6 |
+
import secrets
|
7 |
|
8 |
+
# Function to embed a watermark into the model using LSB technique
|
9 |
+
def embed_watermark_LSB(model, watermark_data):
|
10 |
+
"""
|
11 |
+
Embeds a watermark into the provided model using Least Significant Bit (LSB) technique.
|
12 |
|
13 |
+
Arguments:
|
14 |
+
model : object
|
15 |
+
The machine learning model object (e.g., TensorFlow/Keras model).
|
16 |
+
watermark_data : str
|
17 |
+
The watermark data to be embedded into the model.
|
18 |
|
19 |
+
Returns:
|
20 |
+
model : object
|
21 |
+
The model with the embedded watermark.
|
22 |
+
"""
|
23 |
|
24 |
+
# Convert watermark data to bytes
|
25 |
+
watermark_bytes = watermark_data.encode('utf-8')
|
26 |
|
27 |
+
# Ensure the watermark is within the capacity of the model parameters
|
28 |
+
total_capacity = sum([np.prod(w.shape) for w in model.get_weights()])
|
29 |
+
required_capacity = len(watermark_bytes) * 8 # 8 bits per byte
|
30 |
+
if required_capacity > total_capacity:
|
31 |
+
raise ValueError("Watermark size exceeds model capacity")
|
32 |
|
33 |
+
# Flatten and concatenate all model parameters
|
34 |
+
flattened_weights = np.concatenate([w.flatten() for w in model.get_weights()])
|
|
|
|
|
|
|
35 |
|
36 |
+
# Embed watermark bits into the least significant bits of model parameters
|
37 |
+
watermark_bits = ''.join(format(byte, '08b') for byte in watermark_bytes)
|
38 |
+
watermark_bits += '1' # Adding stop bit
|
39 |
+
for i, bit in enumerate(watermark_bits):
|
40 |
+
flattened_weights[i] = (flattened_weights[i] & ~1) | int(bit)
|
41 |
|
42 |
+
# Reshape and update model parameters with embedded watermark
|
43 |
+
updated_weights = np.split(flattened_weights, [np.prod(w.shape) for w in model.get_weights()])
|
44 |
+
model.set_weights([w.reshape(s) for w, s in zip(updated_weights, [w.shape for w in model.get_weights()])])
|
45 |
|
46 |
+
return model
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
+
# Function to detect and extract the watermark from the model using LSB detection
|
49 |
+
def detect_watermark_LSB(model):
|
50 |
+
"""
|
51 |
+
Detects and extracts the watermark from the provided model using Least Significant Bit (LSB) technique.
|
52 |
|
53 |
+
Arguments:
|
54 |
+
model : object
|
55 |
+
The machine learning model object (e.g., TensorFlow/Keras model).
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
detected_watermark : str or None
|
59 |
+
Extracted watermark if detected, else None.
|
60 |
+
"""
|
61 |
+
|
62 |
+
# Flatten and concatenate all model parameters
|
63 |
+
flattened_weights = np.concatenate([w.flatten() for w in model.get_weights()])
|
64 |
+
|
65 |
+
# Extract watermark bits from the least significant bits of model parameters
|
66 |
+
watermark_bits = ''
|
67 |
+
stop_bit = '1'
|
68 |
+
for bit in flattened_weights:
|
69 |
+
bit = int(bit) & 1
|
70 |
+
watermark_bits += str(bit)
|
71 |
+
if watermark_bits.endswith(stop_bit):
|
72 |
+
break
|
73 |
+
|
74 |
+
# Convert extracted bits to bytes and decode watermark
|
75 |
+
watermark_bytes = [int(watermark_bits[i:i+8], 2) for i in range(0, len(watermark_bits), 8)]
|
76 |
+
detected_watermark = bytearray(watermark_bytes).decode('utf-8')
|
77 |
+
|
78 |
+
return detected_watermark
|