Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- app.py +300 -173
- detectability.py +40 -59
- paraphraser.py +65 -65
- sankey.py +86 -0
app.py
CHANGED
@@ -1,11 +1,8 @@
|
|
1 |
import nltk
|
2 |
nltk.download('stopwords')
|
3 |
-
# from transformers import AutoTokenizer
|
4 |
-
# from transformers import AutoModelForSeq2SeqLM
|
5 |
-
import plotly.graph_objs as go
|
6 |
-
from transformers import pipeline
|
7 |
import random
|
8 |
import gradio as gr
|
|
|
9 |
from tree import generate_subplot1, generate_subplot2
|
10 |
from paraphraser import generate_paraphrase
|
11 |
from lcs import find_common_subsequences, find_common_gram_positions
|
@@ -17,196 +14,326 @@ from detectability import SentenceDetectabilityCalculator
|
|
17 |
from distortion import SentenceDistortionCalculator
|
18 |
from euclidean_distance import SentenceEuclideanDistanceCalculator
|
19 |
from threeD_plot import gen_three_D_plot
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
def
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
masked_words.append(words)
|
50 |
-
masked_logits.append(logits)
|
51 |
-
|
52 |
-
sampled_sentences = []
|
53 |
-
for masked_sent, words, logits in zip(masked_sentences, masked_words, masked_logits):
|
54 |
-
sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='inverse_transform', temperature=1.0))
|
55 |
-
sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='exponential_minimum', temperature=1.0))
|
56 |
-
sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='temperature', temperature=1.0))
|
57 |
-
sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='greedy', temperature=1.0))
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
colors = ["red", "blue", "brown", "green"]
|
63 |
-
|
64 |
-
def select_color():
|
65 |
-
return random.choice(colors)
|
66 |
-
|
67 |
-
highlight_info = [(word, select_color()) for _, word in common_grams]
|
68 |
-
|
69 |
-
highlighted_user_prompt = highlight_common_words(common_grams, [user_prompt], "Non-melting Points in the User Prompt")
|
70 |
-
highlighted_accepted_sentences = highlight_common_words_dict(common_grams, selected_sentences, "Paraphrased Sentences")
|
71 |
-
highlighted_discarded_sentences = highlight_common_words_dict(common_grams, discarded_sentences, "Discarded Sentences")
|
72 |
-
|
73 |
-
trees1 = []
|
74 |
-
trees2 = []
|
75 |
-
|
76 |
-
masked_index = 0
|
77 |
-
sampled_index = 0
|
78 |
-
|
79 |
-
for i, sentence in enumerate(paraphrased_sentences):
|
80 |
-
next_masked_sentences = masked_sentences[masked_index:masked_index + 3]
|
81 |
-
next_sampled_sentences = sampled_sentences[sampled_index:sampled_index + 12]
|
82 |
-
|
83 |
-
tree1 = generate_subplot1(sentence, next_masked_sentences, highlight_info, common_grams)
|
84 |
-
trees1.append(tree1)
|
85 |
-
|
86 |
-
tree2 = generate_subplot2(next_masked_sentences, next_sampled_sentences, highlight_info, common_grams)
|
87 |
-
trees2.append(tree2)
|
88 |
-
|
89 |
-
masked_index += 3
|
90 |
-
sampled_index += 12
|
91 |
-
|
92 |
-
reparaphrased_sentences = generate_paraphrase(sampled_sentences)
|
93 |
-
|
94 |
-
len_reparaphrased_sentences = len(reparaphrased_sentences)
|
95 |
-
|
96 |
-
reparaphrased_sentences_list = []
|
97 |
-
|
98 |
-
# Process the sentences in batches of 10
|
99 |
-
for i in range(0, len_reparaphrased_sentences, 10):
|
100 |
-
# Get the current batch of 10 sentences
|
101 |
-
batch = reparaphrased_sentences[i:i + 10]
|
102 |
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
|
148 |
-
|
149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
|
151 |
-
|
152 |
-
|
153 |
|
154 |
-
|
155 |
-
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
-
|
159 |
-
|
|
|
160 |
|
161 |
-
with gr.Row():
|
162 |
with gr.Tabs():
|
163 |
-
with gr.TabItem("Paraphrased Sentences"):
|
164 |
highlighted_accepted_sentences = gr.HTML()
|
165 |
-
with gr.TabItem("Discarded Sentences"):
|
166 |
highlighted_discarded_sentences = gr.HTML()
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
gr.
|
171 |
-
|
172 |
with gr.Tabs():
|
173 |
tree1_tabs = []
|
174 |
-
for i in range(10):
|
175 |
-
with gr.TabItem(f"Sentence {i+1}"):
|
176 |
tree1 = gr.Plot()
|
177 |
tree1_tabs.append(tree1)
|
178 |
-
|
179 |
-
|
180 |
-
gr.Markdown("
|
181 |
-
|
|
|
182 |
with gr.Tabs():
|
183 |
tree2_tabs = []
|
184 |
-
for i in range(10):
|
185 |
-
with gr.TabItem(f"Sentence {i+1}"):
|
186 |
tree2 = gr.Plot()
|
187 |
tree2_tabs.append(tree2)
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
gr.
|
192 |
-
|
193 |
-
# Adding tabs for the re-paraphrased sentences
|
194 |
-
with gr.Row():
|
195 |
with gr.Tabs():
|
196 |
reparaphrased_sentences_tabs = []
|
197 |
-
for i in range(120):
|
198 |
-
with gr.TabItem(f"
|
199 |
-
reparaphrased_sent_html = gr.HTML()
|
200 |
reparaphrased_sentences_tabs.append(reparaphrased_sent_html)
|
201 |
-
|
202 |
-
|
203 |
-
gr.Markdown("
|
204 |
-
|
|
|
205 |
three_D_plot = gr.Plot()
|
|
|
206 |
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import nltk
|
2 |
nltk.download('stopwords')
|
|
|
|
|
|
|
|
|
3 |
import random
|
4 |
import gradio as gr
|
5 |
+
import time
|
6 |
from tree import generate_subplot1, generate_subplot2
|
7 |
from paraphraser import generate_paraphrase
|
8 |
from lcs import find_common_subsequences, find_common_gram_positions
|
|
|
14 |
from distortion import SentenceDistortionCalculator
|
15 |
from euclidean_distance import SentenceEuclideanDistanceCalculator
|
16 |
from threeD_plot import gen_three_D_plot
|
17 |
+
from sankey import generate_sankey_diagram
|
18 |
+
|
19 |
+
class WatermarkingPipeline:
|
20 |
+
def __init__(self):
|
21 |
+
# Existing initialization code...
|
22 |
+
self.user_prompt = None
|
23 |
+
self.paraphrased_sentences = None
|
24 |
+
self.analyzed_paraphrased_sentences = None
|
25 |
+
self.selected_sentences = None
|
26 |
+
self.discarded_sentences = None
|
27 |
+
self.common_grams = None
|
28 |
+
self.subsequences = None
|
29 |
+
self.common_grams_position = None
|
30 |
+
self.masked_sentences = None
|
31 |
+
self.masked_words = None
|
32 |
+
self.masked_logits = None
|
33 |
+
self.sampled_sentences = None
|
34 |
+
self.reparaphrased_sentences = None
|
35 |
+
self.distortion_list = None
|
36 |
+
self.detectability_list = None
|
37 |
+
self.euclidean_dist_list = None
|
38 |
+
|
39 |
+
def step1_paraphrasing(self, prompt, threshold=0.7):
|
40 |
+
start_time = time.time()
|
41 |
|
42 |
+
# Existing step1 code...
|
43 |
+
self.user_prompt = prompt
|
44 |
+
self.paraphrased_sentences = generate_paraphrase(prompt)
|
45 |
+
if self.paraphrased_sentences is None:
|
46 |
+
return "Error in generating paraphrases", "Error: Could not complete step"
|
47 |
|
48 |
+
self.analyzed_paraphrased_sentences, self.selected_sentences, self.discarded_sentences = \
|
49 |
+
analyze_entailment(self.user_prompt, self.paraphrased_sentences, threshold)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
+
self.common_grams = find_common_subsequences(self.user_prompt, self.selected_sentences)
|
52 |
+
self.subsequences = [subseq for _, subseq in self.common_grams]
|
53 |
+
self.common_grams_position = find_common_gram_positions(self.selected_sentences, self.subsequences)
|
54 |
+
|
55 |
+
colors = ["red", "blue", "brown", "green"]
|
56 |
+
def select_color():
|
57 |
+
return random.choice(colors)
|
58 |
+
highlight_info = [(word, select_color()) for _, word in self.common_grams]
|
59 |
+
|
60 |
+
highlighted_user_prompt = highlight_common_words(
|
61 |
+
self.common_grams, [self.user_prompt], "Highlighted LCS in the User Prompt"
|
62 |
+
)
|
63 |
+
highlighted_accepted_sentences = highlight_common_words_dict(
|
64 |
+
self.common_grams, self.selected_sentences, "Paraphrased Sentences"
|
65 |
+
)
|
66 |
+
highlighted_discarded_sentences = highlight_common_words_dict(
|
67 |
+
self.common_grams, self.discarded_sentences, "Discarded Sentences"
|
68 |
+
)
|
69 |
+
|
70 |
+
execution_time = time.time() - start_time
|
71 |
+
time_info = f"Step 1 completed in {execution_time:.2f} seconds"
|
72 |
+
|
73 |
+
return [
|
74 |
+
highlighted_user_prompt,
|
75 |
+
highlighted_accepted_sentences,
|
76 |
+
highlighted_discarded_sentences,
|
77 |
+
time_info
|
78 |
+
]
|
79 |
+
|
80 |
+
def step2_masking(self):
|
81 |
+
start_time = time.time()
|
82 |
+
|
83 |
+
if self.paraphrased_sentences is None:
|
84 |
+
return [None] * 10 + ["Error: Please complete step 1 first"]
|
85 |
+
|
86 |
+
# Existing step2 code...
|
87 |
+
self.masked_sentences = []
|
88 |
+
self.masked_words = []
|
89 |
+
self.masked_logits = []
|
90 |
+
|
91 |
+
for sentence in self.paraphrased_sentences:
|
92 |
+
for mask_func in [mask_non_stopword, mask_non_stopword_pseudorandom,
|
93 |
+
lambda s: high_entropy_words(s, self.common_grams)]:
|
94 |
+
masked_sent, logits, words = mask_func(sentence)
|
95 |
+
self.masked_sentences.append(masked_sent)
|
96 |
+
self.masked_words.append(words)
|
97 |
+
self.masked_logits.append(logits)
|
98 |
+
|
99 |
+
trees = []
|
100 |
+
masked_index = 0
|
101 |
+
colors = ["red", "blue", "brown", "green"]
|
102 |
+
highlight_info = [(word, random.choice(colors)) for _, word in self.common_grams]
|
103 |
+
|
104 |
+
for i, sentence in enumerate(self.paraphrased_sentences):
|
105 |
+
next_masked = self.masked_sentences[masked_index:masked_index + 3]
|
106 |
+
tree = generate_subplot1(sentence, next_masked, highlight_info, self.common_grams)
|
107 |
+
trees.append(tree)
|
108 |
+
masked_index += 3
|
109 |
+
|
110 |
+
execution_time = time.time() - start_time
|
111 |
+
time_info = f"Step 2 completed in {execution_time:.2f} seconds"
|
112 |
+
|
113 |
+
return trees + [time_info]
|
114 |
|
115 |
+
def step3_sampling(self):
|
116 |
+
start_time = time.time()
|
117 |
+
|
118 |
+
if self.masked_sentences is None:
|
119 |
+
return [None] * 10 + ["Error: Please complete step 2 first"]
|
120 |
+
|
121 |
+
# Existing step3 code...
|
122 |
+
self.sampled_sentences = []
|
123 |
+
trees = []
|
124 |
+
colors = ["red", "blue", "brown", "green"]
|
125 |
+
highlight_info = [(word, random.choice(colors)) for _, word in self.common_grams]
|
126 |
+
|
127 |
+
sampling_techniques = [
|
128 |
+
('inverse_transform', 1.0),
|
129 |
+
('exponential_minimum', 1.0),
|
130 |
+
('temperature', 1.0),
|
131 |
+
('greedy', 1.0)
|
132 |
+
]
|
133 |
+
|
134 |
+
masked_index = 0
|
135 |
+
while masked_index < len(self.masked_sentences):
|
136 |
+
current_masked = self.masked_sentences[masked_index:masked_index + 3]
|
137 |
+
current_words = self.masked_words[masked_index:masked_index + 3]
|
138 |
+
current_logits = self.masked_logits[masked_index:masked_index + 3]
|
139 |
+
|
140 |
+
batch_samples = []
|
141 |
+
for masked_sent, words, logits in zip(current_masked, current_words, current_logits):
|
142 |
+
for technique, temp in sampling_techniques:
|
143 |
+
sampled = sample_word(masked_sent, words, logits,
|
144 |
+
sampling_technique=technique,
|
145 |
+
temperature=temp)
|
146 |
+
batch_samples.append(sampled)
|
147 |
+
|
148 |
+
self.sampled_sentences.extend(batch_samples)
|
149 |
+
|
150 |
+
if current_masked:
|
151 |
+
tree = generate_subplot2(
|
152 |
+
current_masked,
|
153 |
+
batch_samples,
|
154 |
+
highlight_info,
|
155 |
+
self.common_grams
|
156 |
+
)
|
157 |
+
trees.append(tree)
|
158 |
+
|
159 |
+
masked_index += 3
|
160 |
+
|
161 |
+
if len(trees) < 10:
|
162 |
+
trees.extend([None] * (10 - len(trees)))
|
163 |
+
|
164 |
+
execution_time = time.time() - start_time
|
165 |
+
time_info = f"Step 3 completed in {execution_time:.2f} seconds"
|
166 |
+
|
167 |
+
return trees[:10] + [time_info]
|
168 |
+
|
169 |
+
def step4_reparaphrase(self):
|
170 |
+
start_time = time.time()
|
171 |
+
|
172 |
+
if self.sampled_sentences is None:
|
173 |
+
return ["Error: Please complete step 3 first"] * 120 + ["Error: Please complete step 3 first"]
|
174 |
+
|
175 |
+
# Existing step4 code...
|
176 |
+
self.reparaphrased_sentences = []
|
177 |
+
for i in range(13):
|
178 |
+
self.reparaphrased_sentences.append(generate_paraphrase(self.sampled_sentences[i]))
|
179 |
+
|
180 |
+
reparaphrased_sentences_list = []
|
181 |
+
for i in range(0, len(self.reparaphrased_sentences), 10):
|
182 |
+
batch = self.reparaphrased_sentences[i:i + 10]
|
183 |
+
if len(batch) == 10:
|
184 |
+
html_block = reparaphrased_sentences_html(batch)
|
185 |
+
reparaphrased_sentences_list.append(html_block)
|
186 |
+
|
187 |
+
execution_time = time.time() - start_time
|
188 |
+
time_info = f"Step 4 completed in {execution_time:.2f} seconds"
|
189 |
+
|
190 |
+
return reparaphrased_sentences_list + [time_info]
|
191 |
+
|
192 |
+
def step5_metrics(self):
|
193 |
+
start_time = time.time()
|
194 |
+
|
195 |
+
if self.reparaphrased_sentences is None:
|
196 |
+
return "Please complete step 4 first", "Error: Please complete step 4 first"
|
197 |
+
|
198 |
+
# Existing step5 code...
|
199 |
+
distortion_calculator = SentenceDistortionCalculator(self.user_prompt, self.reparaphrased_sentences)
|
200 |
+
distortion_calculator.calculate_all_metrics()
|
201 |
+
distortion_calculator.normalize_metrics()
|
202 |
+
distortion_calculator.calculate_combined_distortion()
|
203 |
+
distortion = distortion_calculator.get_combined_distortions()
|
204 |
+
self.distortion_list = [each[1] for each in distortion.items()]
|
205 |
+
|
206 |
+
detectability_calculator = SentenceDetectabilityCalculator(self.user_prompt, self.reparaphrased_sentences)
|
207 |
+
detectability_calculator.calculate_all_metrics()
|
208 |
+
detectability_calculator.normalize_metrics()
|
209 |
+
detectability_calculator.calculate_combined_detectability()
|
210 |
+
detectability = detectability_calculator.get_combined_detectabilities()
|
211 |
+
self.detectability_list = [each[1] for each in detectability.items()]
|
212 |
+
|
213 |
+
euclidean_dist_calculator = SentenceEuclideanDistanceCalculator(self.user_prompt, self.reparaphrased_sentences)
|
214 |
+
euclidean_dist_calculator.calculate_all_metrics()
|
215 |
+
euclidean_dist_calculator.normalize_metrics()
|
216 |
+
euclidean_dist = detectability_calculator.get_combined_detectabilities()
|
217 |
+
self.euclidean_dist_list = [each[1] for each in euclidean_dist.items()]
|
218 |
+
|
219 |
+
three_D_plot = gen_three_D_plot(
|
220 |
+
self.detectability_list,
|
221 |
+
self.distortion_list,
|
222 |
+
self.euclidean_dist_list
|
223 |
+
)
|
224 |
+
|
225 |
+
execution_time = time.time() - start_time
|
226 |
+
time_info = f"Step 5 completed in {execution_time:.2f} seconds"
|
227 |
+
|
228 |
+
return three_D_plot, time_info
|
229 |
|
230 |
+
def step6_sankey(self):
|
231 |
+
return generate_sankey_diagram()
|
232 |
|
233 |
+
def create_gradio_interface():
|
234 |
+
pipeline = WatermarkingPipeline()
|
235 |
+
|
236 |
+
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
237 |
+
gr.Markdown("# **AIISC Watermarking Model**")
|
238 |
+
|
239 |
+
with gr.Column():
|
240 |
+
gr.Markdown("## Input Prompt")
|
241 |
+
user_input = gr.Textbox(label="Enter Your Prompt")
|
242 |
|
243 |
+
gr.Markdown("## Step 1: Paraphrasing, LCS and Entailment Analysis")
|
244 |
+
paraphrase_button = gr.Button("Generate Paraphrases")
|
245 |
+
highlighted_user_prompt = gr.HTML(label="Highlighted User Prompt")
|
246 |
|
|
|
247 |
with gr.Tabs():
|
248 |
+
with gr.TabItem("Accepted Paraphrased Sentences"):
|
249 |
highlighted_accepted_sentences = gr.HTML()
|
250 |
+
with gr.TabItem("Discarded Paraphrased Sentences"):
|
251 |
highlighted_discarded_sentences = gr.HTML()
|
252 |
+
step1_time = gr.Textbox(label="Execution Time", interactive=False)
|
253 |
+
|
254 |
+
gr.Markdown("## Step 2: Where to Mask?")
|
255 |
+
masking_button = gr.Button("Apply Masking")
|
256 |
+
gr.Markdown("### Masked Sentence Trees")
|
257 |
with gr.Tabs():
|
258 |
tree1_tabs = []
|
259 |
+
for i in range(10):
|
260 |
+
with gr.TabItem(f"Masked Sentence {i+1}"):
|
261 |
tree1 = gr.Plot()
|
262 |
tree1_tabs.append(tree1)
|
263 |
+
step2_time = gr.Textbox(label="Execution Time", interactive=False)
|
264 |
+
|
265 |
+
gr.Markdown("## Step 3: How to Mask?")
|
266 |
+
sampling_button = gr.Button("Sample Words")
|
267 |
+
gr.Markdown("### Sampled Sentence Trees")
|
268 |
with gr.Tabs():
|
269 |
tree2_tabs = []
|
270 |
+
for i in range(10):
|
271 |
+
with gr.TabItem(f"Sampled Sentence {i+1}"):
|
272 |
tree2 = gr.Plot()
|
273 |
tree2_tabs.append(tree2)
|
274 |
+
step3_time = gr.Textbox(label="Execution Time", interactive=False)
|
275 |
+
|
276 |
+
gr.Markdown("## Step 4: Re-paraphrasing")
|
277 |
+
reparaphrase_button = gr.Button("Re-paraphrase")
|
278 |
+
gr.Markdown("### Reparaphrased Sentences")
|
|
|
|
|
279 |
with gr.Tabs():
|
280 |
reparaphrased_sentences_tabs = []
|
281 |
+
for i in range(120):
|
282 |
+
with gr.TabItem(f"Reparaphrased Batch {i+1}"):
|
283 |
+
reparaphrased_sent_html = gr.HTML()
|
284 |
reparaphrased_sentences_tabs.append(reparaphrased_sent_html)
|
285 |
+
step4_time = gr.Textbox(label="Execution Time", interactive=False)
|
286 |
+
|
287 |
+
gr.Markdown("## Step 5: Finding Sweet Spot")
|
288 |
+
metrics_button = gr.Button("Calculate Metrics")
|
289 |
+
gr.Markdown("### 3D Visualization of Metrics")
|
290 |
three_D_plot = gr.Plot()
|
291 |
+
step5_time = gr.Textbox(label="Execution Time", interactive=False)
|
292 |
|
293 |
+
# Sankey Diagram
|
294 |
+
gr.Markdown("# Watermarking Pipeline Flow Visualization")
|
295 |
+
generate_button = gr.Button("Generate Sankey Diagram")
|
296 |
+
sankey_plot = gr.Plot()
|
297 |
+
|
298 |
+
paraphrase_button.click(
|
299 |
+
pipeline.step1_paraphrasing,
|
300 |
+
inputs=user_input,
|
301 |
+
outputs=[highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences, step1_time]
|
302 |
+
)
|
303 |
+
|
304 |
+
masking_button.click(
|
305 |
+
pipeline.step2_masking,
|
306 |
+
inputs=None,
|
307 |
+
outputs=tree1_tabs + [step2_time]
|
308 |
+
)
|
309 |
+
|
310 |
+
sampling_button.click(
|
311 |
+
pipeline.step3_sampling,
|
312 |
+
inputs=None,
|
313 |
+
outputs=tree2_tabs + [step3_time],
|
314 |
+
show_progress=True
|
315 |
+
)
|
316 |
+
|
317 |
+
reparaphrase_button.click(
|
318 |
+
pipeline.step4_reparaphrase,
|
319 |
+
inputs=None,
|
320 |
+
outputs=reparaphrased_sentences_tabs + [step4_time]
|
321 |
+
)
|
322 |
+
|
323 |
+
metrics_button.click(
|
324 |
+
pipeline.step5_metrics,
|
325 |
+
inputs=None,
|
326 |
+
outputs=[three_D_plot, step5_time]
|
327 |
+
)
|
328 |
+
|
329 |
+
generate_button.click(
|
330 |
+
pipeline.step6_sankey,
|
331 |
+
inputs=None,
|
332 |
+
outputs=sankey_plot
|
333 |
+
)
|
334 |
+
|
335 |
+
return demo
|
336 |
+
|
337 |
+
if __name__ == "__main__":
|
338 |
+
demo = create_gradio_interface()
|
339 |
+
demo.launch(share=True)
|
detectability.py
CHANGED
@@ -6,12 +6,12 @@ import torch
|
|
6 |
import matplotlib.pyplot as plt
|
7 |
from sklearn.metrics.pairwise import cosine_similarity
|
8 |
from transformers import BertModel, BertTokenizer
|
9 |
-
from
|
10 |
-
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
11 |
|
12 |
# Download NLTK data if not already present
|
13 |
nltk.download('punkt', quiet=True)
|
14 |
-
detectability_val={}
|
|
|
15 |
class SentenceDetectabilityCalculator:
|
16 |
"""
|
17 |
A class to calculate and analyze detectability metrics between an original sentence and paraphrased sentences.
|
@@ -25,63 +25,62 @@ class SentenceDetectabilityCalculator:
|
|
25 |
self.paraphrased_sentences = paraphrased_sentences
|
26 |
|
27 |
# Raw metric dictionaries
|
28 |
-
self.
|
29 |
-
self.
|
30 |
-
self.
|
31 |
|
32 |
# Normalized metric dictionaries
|
33 |
-
self.
|
34 |
-
self.
|
35 |
-
self.normalized_sts = {}
|
36 |
|
37 |
# Combined detectability dictionary
|
38 |
self.combined_detectabilities = {}
|
39 |
|
40 |
-
# Load pre-trained BERT
|
41 |
self.bert_model = BertModel.from_pretrained('bert-base-uncased')
|
42 |
self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
43 |
-
self.sts_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
|
44 |
|
45 |
def calculate_all_metrics(self):
|
46 |
"""
|
47 |
-
Calculate
|
48 |
"""
|
49 |
original_embedding = self._get_sentence_embedding(self.original_sentence)
|
50 |
-
sts_original_embedding = self.sts_model.encode(self.original_sentence)
|
51 |
|
|
|
52 |
for idx, paraphrased_sentence in enumerate(self.paraphrased_sentences):
|
53 |
-
key = f"Sentence_{idx+1}"
|
54 |
-
|
55 |
-
# BLEU Score
|
56 |
-
self.bleu_scores[key] = self._calculate_bleu(self.original_sentence, paraphrased_sentence)
|
57 |
-
|
58 |
-
# Cosine Similarity
|
59 |
paraphrase_embedding = self._get_sentence_embedding(paraphrased_sentence)
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
65 |
|
66 |
def normalize_metrics(self):
|
67 |
"""
|
68 |
-
Normalize
|
69 |
"""
|
70 |
-
self.
|
71 |
-
self.
|
72 |
-
self.normalized_sts = self._normalize_dict(self.sts_scores)
|
73 |
|
74 |
def calculate_combined_detectability(self):
|
75 |
"""
|
76 |
Calculate the combined detectability using the root mean square of the normalized metrics.
|
77 |
"""
|
78 |
-
for key in self.
|
79 |
rms = np.sqrt(
|
80 |
(
|
81 |
-
self.
|
82 |
-
self.
|
83 |
-
|
84 |
-
) / 3
|
85 |
)
|
86 |
self.combined_detectabilities[key] = rms
|
87 |
|
@@ -89,14 +88,13 @@ class SentenceDetectabilityCalculator:
|
|
89 |
"""
|
90 |
Plot each normalized metric and the combined detectability in separate graphs.
|
91 |
"""
|
92 |
-
keys = list(self.
|
93 |
indices = np.arange(len(keys))
|
94 |
|
95 |
# Prepare data for plotting
|
96 |
metrics = {
|
97 |
-
'
|
98 |
-
'
|
99 |
-
'STS Score': [self.normalized_sts[key] for key in keys],
|
100 |
'Combined Detectability': [self.combined_detectabilities[key] for key in keys]
|
101 |
}
|
102 |
|
@@ -111,16 +109,7 @@ class SentenceDetectabilityCalculator:
|
|
111 |
plt.tight_layout()
|
112 |
plt.show()
|
113 |
|
114 |
-
# Private methods
|
115 |
-
def _calculate_bleu(self, reference, candidate):
|
116 |
-
"""
|
117 |
-
Calculate the BLEU score between the original and paraphrased sentence using smoothing.
|
118 |
-
"""
|
119 |
-
reference_tokens = nltk.word_tokenize(reference)
|
120 |
-
candidate_tokens = nltk.word_tokenize(candidate)
|
121 |
-
smoothing = SmoothingFunction().method1
|
122 |
-
return sentence_bleu([reference_tokens], candidate_tokens, smoothing_function=smoothing)
|
123 |
-
|
124 |
def _get_sentence_embedding(self, sentence):
|
125 |
"""
|
126 |
Get sentence embedding using BERT.
|
@@ -150,9 +139,8 @@ class SentenceDetectabilityCalculator:
|
|
150 |
Get all normalized metrics as a dictionary.
|
151 |
"""
|
152 |
return {
|
153 |
-
'
|
154 |
-
'
|
155 |
-
'STS Score': self.normalized_sts
|
156 |
}
|
157 |
|
158 |
def get_combined_detectabilities(self):
|
@@ -310,7 +298,6 @@ if __name__ == "__main__":
|
|
310 |
"Final observation: Red subject shows mobility over Gray subject."
|
311 |
]
|
312 |
|
313 |
-
|
314 |
# Initialize the calculator
|
315 |
calculator = SentenceDetectabilityCalculator(original_sentence, paraphrased_sentences)
|
316 |
|
@@ -326,18 +313,12 @@ if __name__ == "__main__":
|
|
326 |
# Retrieve the normalized metrics and combined detectabilities
|
327 |
normalized_metrics = calculator.get_normalized_metrics()
|
328 |
combined_detectabilities = calculator.get_combined_detectabilities()
|
329 |
-
detectability_val=combined_detectabilities
|
330 |
|
331 |
# Display the results
|
332 |
-
# print("Normalized Metrics:")
|
333 |
-
# for metric_name, metric_dict in normalized_metrics.items():
|
334 |
-
# print(f"\n{metric_name}:")
|
335 |
-
# for key, value in metric_dict.items():
|
336 |
-
# print(f"{key}: {value:.4f}")
|
337 |
-
|
338 |
print("\nCombined Detectabilities:")
|
339 |
for each in combined_detectabilities.items():
|
340 |
print(f"{each[1]}")
|
341 |
|
342 |
-
# Plot the metrics
|
343 |
-
#
|
|
|
6 |
import matplotlib.pyplot as plt
|
7 |
from sklearn.metrics.pairwise import cosine_similarity
|
8 |
from transformers import BertModel, BertTokenizer
|
9 |
+
from scipy import stats
|
|
|
10 |
|
11 |
# Download NLTK data if not already present
|
12 |
nltk.download('punkt', quiet=True)
|
13 |
+
detectability_val = {}
|
14 |
+
|
15 |
class SentenceDetectabilityCalculator:
|
16 |
"""
|
17 |
A class to calculate and analyze detectability metrics between an original sentence and paraphrased sentences.
|
|
|
25 |
self.paraphrased_sentences = paraphrased_sentences
|
26 |
|
27 |
# Raw metric dictionaries
|
28 |
+
self.z_scores = {}
|
29 |
+
self.p_values = {}
|
30 |
+
self.metric_values = []
|
31 |
|
32 |
# Normalized metric dictionaries
|
33 |
+
self.normalized_z_scores = {}
|
34 |
+
self.normalized_p_values = {}
|
|
|
35 |
|
36 |
# Combined detectability dictionary
|
37 |
self.combined_detectabilities = {}
|
38 |
|
39 |
+
# Load pre-trained BERT for embeddings
|
40 |
self.bert_model = BertModel.from_pretrained('bert-base-uncased')
|
41 |
self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
|
|
42 |
|
43 |
def calculate_all_metrics(self):
|
44 |
"""
|
45 |
+
Calculate detectability metrics for each paraphrased sentence.
|
46 |
"""
|
47 |
original_embedding = self._get_sentence_embedding(self.original_sentence)
|
|
|
48 |
|
49 |
+
# First, compute the metric values (cosine similarities)
|
50 |
for idx, paraphrased_sentence in enumerate(self.paraphrased_sentences):
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
paraphrase_embedding = self._get_sentence_embedding(paraphrased_sentence)
|
52 |
+
cosine_sim = cosine_similarity([original_embedding], [paraphrase_embedding])[0][0]
|
53 |
+
self.metric_values.append(cosine_sim)
|
54 |
+
|
55 |
+
# Compute mean and standard deviation of the metric values
|
56 |
+
metric_mean = np.mean(self.metric_values)
|
57 |
+
metric_std = np.std(self.metric_values)
|
58 |
|
59 |
+
# Compute z-scores and p-values
|
60 |
+
for idx, (paraphrased_sentence, metric_value) in enumerate(zip(self.paraphrased_sentences, self.metric_values)):
|
61 |
+
key = f"Sentence_{idx+1}"
|
62 |
+
z_score = (metric_value - metric_mean) / metric_std if metric_std != 0 else 0.0
|
63 |
+
p_value = stats.norm.sf(abs(z_score)) * 2 # two-tailed p-value
|
64 |
+
self.z_scores[key] = z_score
|
65 |
+
self.p_values[key] = p_value
|
66 |
|
67 |
def normalize_metrics(self):
|
68 |
"""
|
69 |
+
Normalize z-scores and p-values to be between 0 and 1.
|
70 |
"""
|
71 |
+
self.normalized_z_scores = self._normalize_dict(self.z_scores)
|
72 |
+
self.normalized_p_values = self._normalize_dict(self.p_values)
|
|
|
73 |
|
74 |
def calculate_combined_detectability(self):
|
75 |
"""
|
76 |
Calculate the combined detectability using the root mean square of the normalized metrics.
|
77 |
"""
|
78 |
+
for key in self.normalized_z_scores.keys():
|
79 |
rms = np.sqrt(
|
80 |
(
|
81 |
+
self.normalized_z_scores[key] ** 2 +
|
82 |
+
self.normalized_p_values[key] ** 2
|
83 |
+
) / 2
|
|
|
84 |
)
|
85 |
self.combined_detectabilities[key] = rms
|
86 |
|
|
|
88 |
"""
|
89 |
Plot each normalized metric and the combined detectability in separate graphs.
|
90 |
"""
|
91 |
+
keys = list(self.normalized_z_scores.keys())
|
92 |
indices = np.arange(len(keys))
|
93 |
|
94 |
# Prepare data for plotting
|
95 |
metrics = {
|
96 |
+
'Z-Score': [self.normalized_z_scores[key] for key in keys],
|
97 |
+
'P-Value': [self.normalized_p_values[key] for key in keys],
|
|
|
98 |
'Combined Detectability': [self.combined_detectabilities[key] for key in keys]
|
99 |
}
|
100 |
|
|
|
109 |
plt.tight_layout()
|
110 |
plt.show()
|
111 |
|
112 |
+
# Private methods
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
def _get_sentence_embedding(self, sentence):
|
114 |
"""
|
115 |
Get sentence embedding using BERT.
|
|
|
139 |
Get all normalized metrics as a dictionary.
|
140 |
"""
|
141 |
return {
|
142 |
+
'Z-Score': self.normalized_z_scores,
|
143 |
+
'P-Value': self.normalized_p_values
|
|
|
144 |
}
|
145 |
|
146 |
def get_combined_detectabilities(self):
|
|
|
298 |
"Final observation: Red subject shows mobility over Gray subject."
|
299 |
]
|
300 |
|
|
|
301 |
# Initialize the calculator
|
302 |
calculator = SentenceDetectabilityCalculator(original_sentence, paraphrased_sentences)
|
303 |
|
|
|
313 |
# Retrieve the normalized metrics and combined detectabilities
|
314 |
normalized_metrics = calculator.get_normalized_metrics()
|
315 |
combined_detectabilities = calculator.get_combined_detectabilities()
|
316 |
+
detectability_val = combined_detectabilities
|
317 |
|
318 |
# Display the results
|
|
|
|
|
|
|
|
|
|
|
|
|
319 |
print("\nCombined Detectabilities:")
|
320 |
for each in combined_detectabilities.items():
|
321 |
print(f"{each[1]}")
|
322 |
|
323 |
+
# Plot the metrics (optional)
|
324 |
+
#calculator.plot_metrics()
|
paraphraser.py
CHANGED
@@ -1,32 +1,32 @@
|
|
1 |
-
|
2 |
|
3 |
-
#
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
|
9 |
-
#
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
|
31 |
# print(generate_paraphrase("Donald Trump said at a campaign rally event in Wilkes-Barre, Pennsylvania, that there has “never been a more dangerous time 5since the Holocaust” to be Jewish in the United States."))
|
32 |
|
@@ -34,50 +34,50 @@
|
|
34 |
Accepts a sentence or list of sentences and returns a lit of all their paraphrases using GPT-4.
|
35 |
'''
|
36 |
|
37 |
-
from openai import OpenAI
|
38 |
-
from dotenv import load_dotenv
|
39 |
-
load_dotenv()
|
40 |
-
import os
|
41 |
|
42 |
-
key = os.getenv("OPENAI_API_KEY")
|
43 |
|
44 |
-
# Initialize the OpenAI client
|
45 |
-
client = OpenAI(
|
46 |
-
|
47 |
-
)
|
48 |
|
49 |
-
# Function to paraphrase sentences using GPT-4
|
50 |
-
def generate_paraphrase(sentences, model="gpt-4o", num_paraphrases=10, max_tokens=150, temperature=0.7):
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
|
55 |
-
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
|
79 |
-
|
80 |
|
81 |
-
result = generate_paraphrase("Mayor Eric Adams did not attend the first candidate forum for the New York City mayoral race, but his record — and the criminal charges he faces — received plenty of attention on Saturday from the Democrats who are running to unseat him.")
|
82 |
|
83 |
-
print(len(result))
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
2 |
|
3 |
+
# Function to Initialize the Model
|
4 |
+
def init_model():
|
5 |
+
para_tokenizer = AutoTokenizer.from_pretrained("humarin/chatgpt_paraphraser_on_T5_base")
|
6 |
+
para_model = AutoModelForSeq2SeqLM.from_pretrained("humarin/chatgpt_paraphraser_on_T5_base")
|
7 |
+
return para_tokenizer, para_model
|
8 |
|
9 |
+
# Function to Paraphrase the Text
|
10 |
+
def paraphrase(question, para_tokenizer, para_model, num_beams=10, num_beam_groups=10, num_return_sequences=10, repetition_penalty=10.0, diversity_penalty=3.0, no_repeat_ngram_size=2, temperature=0.7, max_length=64):
|
11 |
+
input_ids = para_tokenizer(
|
12 |
+
f'paraphrase: {question}',
|
13 |
+
return_tensors="pt", padding="longest",
|
14 |
+
max_length=max_length,
|
15 |
+
truncation=True,
|
16 |
+
).input_ids
|
17 |
+
outputs = para_model.generate(
|
18 |
+
input_ids, temperature=temperature, repetition_penalty=repetition_penalty,
|
19 |
+
num_return_sequences=num_return_sequences, no_repeat_ngram_size=no_repeat_ngram_size,
|
20 |
+
num_beams=num_beams, num_beam_groups=num_beam_groups,
|
21 |
+
max_length=max_length, diversity_penalty=diversity_penalty
|
22 |
+
)
|
23 |
+
res = para_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
24 |
+
return res
|
25 |
|
26 |
+
def generate_paraphrase(question):
|
27 |
+
para_tokenizer, para_model = init_model()
|
28 |
+
res = paraphrase(question, para_tokenizer, para_model)
|
29 |
+
return res
|
30 |
|
31 |
# print(generate_paraphrase("Donald Trump said at a campaign rally event in Wilkes-Barre, Pennsylvania, that there has “never been a more dangerous time 5since the Holocaust” to be Jewish in the United States."))
|
32 |
|
|
|
34 |
Accepts a sentence or list of sentences and returns a lit of all their paraphrases using GPT-4.
|
35 |
'''
|
36 |
|
37 |
+
# from openai import OpenAI
|
38 |
+
# from dotenv import load_dotenv
|
39 |
+
# load_dotenv()
|
40 |
+
# import os
|
41 |
|
42 |
+
# key = os.getenv("OPENAI_API_KEY")
|
43 |
|
44 |
+
# # Initialize the OpenAI client
|
45 |
+
# client = OpenAI(
|
46 |
+
# api_key=key # Replace with your actual API key
|
47 |
+
# )
|
48 |
|
49 |
+
# # Function to paraphrase sentences using GPT-4
|
50 |
+
# def generate_paraphrase(sentences, model="gpt-4o", num_paraphrases=10, max_tokens=150, temperature=0.7):
|
51 |
+
# # Ensure sentences is a list even if a single sentence is passed
|
52 |
+
# if isinstance(sentences, str):
|
53 |
+
# sentences = [sentences]
|
54 |
|
55 |
+
# paraphrased_sentences_list = []
|
56 |
|
57 |
+
# for sentence in sentences:
|
58 |
+
# full_prompt = f"Paraphrase the following text: '{sentence}'"
|
59 |
+
# try:
|
60 |
+
# chat_completion = client.chat.completions.create(
|
61 |
+
# messages=[
|
62 |
+
# {
|
63 |
+
# "role": "user",
|
64 |
+
# "content": full_prompt,
|
65 |
+
# }
|
66 |
+
# ],
|
67 |
+
# model=model,
|
68 |
+
# max_tokens=max_tokens,
|
69 |
+
# temperature=temperature,
|
70 |
+
# n=num_paraphrases # Number of paraphrased sentences to generate
|
71 |
+
# )
|
72 |
+
# # Extract the paraphrased sentences from the response
|
73 |
+
# paraphrased_sentences = [choice.message.content.strip() for choice in chat_completion.choices]
|
74 |
+
# # Append paraphrased sentences to the list
|
75 |
+
# paraphrased_sentences_list.extend(paraphrased_sentences)
|
76 |
+
# except Exception as e:
|
77 |
+
# print(f"Error paraphrasing sentence '{sentence}': {e}")
|
78 |
|
79 |
+
# return paraphrased_sentences_list
|
80 |
|
81 |
+
# result = generate_paraphrase("Mayor Eric Adams did not attend the first candidate forum for the New York City mayoral race, but his record — and the criminal charges he faces — received plenty of attention on Saturday from the Democrats who are running to unseat him.")
|
82 |
|
83 |
+
# print(len(result))
|
sankey.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import plotly.graph_objects as go
|
2 |
+
|
3 |
+
def generate_sankey_diagram():
|
4 |
+
pipeline_metrics = {
|
5 |
+
'masking_methods': ['random masking', 'pseudorandom masking', 'high-entropy masking'],
|
6 |
+
'sampling_methods': ['inverse_transform sampling', 'exponential_minimum sampling', 'temperature sampling', 'greedy sampling'],
|
7 |
+
'scores': {
|
8 |
+
('random masking', 'inverse_transform sampling'): {'detectability': 0.8, 'distortion': 0.2},
|
9 |
+
('random masking', 'exponential_minimum sampling'): {'detectability': 0.7, 'distortion': 0.3},
|
10 |
+
('random masking', 'temperature sampling'): {'detectability': 0.6, 'distortion': 0.4},
|
11 |
+
('random masking', 'greedy sampling'): {'detectability': 0.5, 'distortion': 0.5},
|
12 |
+
('pseudorandom masking', 'inverse_transform sampling'): {'detectability': 0.75, 'distortion': 0.25},
|
13 |
+
('pseudorandom masking', 'exponential_minimum sampling'): {'detectability': 0.65, 'distortion': 0.35},
|
14 |
+
('pseudorandom masking', 'temperature sampling'): {'detectability': 0.55, 'distortion': 0.45},
|
15 |
+
('pseudorandom masking', 'greedy sampling'): {'detectability': 0.45, 'distortion': 0.55},
|
16 |
+
('high-entropy masking', 'inverse_transform sampling'): {'detectability': 0.85, 'distortion': 0.15},
|
17 |
+
('high-entropy masking', 'exponential_minimum sampling'): {'detectability': 0.75, 'distortion': 0.25},
|
18 |
+
('high-entropy masking', 'temperature sampling'): {'detectability': 0.65, 'distortion': 0.35},
|
19 |
+
('high-entropy masking', 'greedy sampling'): {'detectability': 0.55, 'distortion': 0.45}
|
20 |
+
}
|
21 |
+
}
|
22 |
+
|
23 |
+
# Find best combination
|
24 |
+
best_score = 0
|
25 |
+
best_combo = None
|
26 |
+
for combo, metrics in pipeline_metrics['scores'].items():
|
27 |
+
score = metrics['detectability'] * (1 - metrics['distortion'])
|
28 |
+
if score > best_score:
|
29 |
+
best_score = score
|
30 |
+
best_combo = combo
|
31 |
+
|
32 |
+
label_list = ['Input'] + pipeline_metrics['masking_methods'] + pipeline_metrics['sampling_methods'] + ['Output']
|
33 |
+
|
34 |
+
source = []
|
35 |
+
target = []
|
36 |
+
value = []
|
37 |
+
colors = []
|
38 |
+
|
39 |
+
# Input to masking methods
|
40 |
+
for i in range(len(pipeline_metrics['masking_methods'])):
|
41 |
+
source.append(0)
|
42 |
+
target.append(i + 1)
|
43 |
+
value.append(1)
|
44 |
+
colors.append('rgba(0,0,255,0.2)' if pipeline_metrics['masking_methods'][i] != best_combo[0] else 'rgba(255,0,0,0.8)')
|
45 |
+
|
46 |
+
# Masking to sampling methods
|
47 |
+
sampling_start = len(pipeline_metrics['masking_methods']) + 1
|
48 |
+
for i, mask in enumerate(pipeline_metrics['masking_methods']):
|
49 |
+
for j, sample in enumerate(pipeline_metrics['sampling_methods']):
|
50 |
+
score = pipeline_metrics['scores'][(mask, sample)]['detectability'] * \
|
51 |
+
(1 - pipeline_metrics['scores'][(mask, sample)]['distortion'])
|
52 |
+
source.append(i + 1)
|
53 |
+
target.append(sampling_start + j)
|
54 |
+
value.append(score)
|
55 |
+
colors.append('rgba(0,0,255,0.2)' if (mask, sample) != best_combo else 'rgba(255,0,0,0.8)')
|
56 |
+
|
57 |
+
# Sampling methods to output
|
58 |
+
output_idx = len(label_list) - 1
|
59 |
+
for i, sample in enumerate(pipeline_metrics['sampling_methods']):
|
60 |
+
source.append(sampling_start + i)
|
61 |
+
target.append(output_idx)
|
62 |
+
value.append(1)
|
63 |
+
colors.append('rgba(0,0,255,0.2)' if sample != best_combo[1] else 'rgba(255,0,0,0.8)')
|
64 |
+
|
65 |
+
fig = go.Figure(data=[go.Sankey(
|
66 |
+
node=dict(
|
67 |
+
pad=15,
|
68 |
+
thickness=20,
|
69 |
+
line=dict(color="black", width=0.5),
|
70 |
+
label=label_list,
|
71 |
+
color="lightblue"
|
72 |
+
),
|
73 |
+
link=dict(
|
74 |
+
source=source,
|
75 |
+
target=target,
|
76 |
+
value=value,
|
77 |
+
color=colors
|
78 |
+
)
|
79 |
+
)])
|
80 |
+
|
81 |
+
fig.update_layout(
|
82 |
+
title_text=f"Watermarking Pipeline Flow<br>Best Combination: {best_combo[0]} + {best_combo[1]}",
|
83 |
+
font_size=12,
|
84 |
+
height=500
|
85 |
+
)
|
86 |
+
return fig
|