Spaces:
Running
Running
import plotly.graph_objects as go | |
_DEFAULT_LINK_COLOR = 'rgba(0,0,255,0.2)'  # faint blue for ordinary links
_BEST_LINK_COLOR = 'rgba(255,0,0,0.8)'     # strong red for links on the best path


def _default_pipeline_metrics():
    """Return the built-in demo metrics table used when the caller supplies none.

    A fresh dict is built on every call so a caller may mutate the result
    without affecting later diagrams.
    """
    return {
        'masking_methods': ['random masking', 'pseudorandom masking', 'high-entropy masking'],
        'sampling_methods': ['inverse_transform sampling', 'exponential_minimum sampling', 'temperature sampling', 'greedy sampling'],
        'scores': {
            ('random masking', 'inverse_transform sampling'): {'detectability': 0.8, 'distortion': 0.2},
            ('random masking', 'exponential_minimum sampling'): {'detectability': 0.7, 'distortion': 0.3},
            ('random masking', 'temperature sampling'): {'detectability': 0.6, 'distortion': 0.4},
            ('random masking', 'greedy sampling'): {'detectability': 0.5, 'distortion': 0.5},
            ('pseudorandom masking', 'inverse_transform sampling'): {'detectability': 0.75, 'distortion': 0.25},
            ('pseudorandom masking', 'exponential_minimum sampling'): {'detectability': 0.65, 'distortion': 0.35},
            ('pseudorandom masking', 'temperature sampling'): {'detectability': 0.55, 'distortion': 0.45},
            ('pseudorandom masking', 'greedy sampling'): {'detectability': 0.45, 'distortion': 0.55},
            ('high-entropy masking', 'inverse_transform sampling'): {'detectability': 0.85, 'distortion': 0.15},
            ('high-entropy masking', 'exponential_minimum sampling'): {'detectability': 0.75, 'distortion': 0.25},
            ('high-entropy masking', 'temperature sampling'): {'detectability': 0.65, 'distortion': 0.35},
            ('high-entropy masking', 'greedy sampling'): {'detectability': 0.55, 'distortion': 0.45},
        },
    }


def _combo_score(metrics):
    """Return the ranking score ``detectability * (1 - distortion)`` of one combination."""
    return metrics['detectability'] * (1 - metrics['distortion'])


def _build_sankey_data(pipeline_metrics):
    """Compute node labels, link arrays and the best combination for the diagram.

    Nodes are laid out as ``Input``, then every masking method, then every
    sampling method, then ``Output``. Each masking/sampling pair is linked
    with its combination score. The Input->masking and sampling->Output links
    carry the summed score of the node they attach to, so every intermediate
    node's inflow equals its outflow and node heights reflect aggregate score.

    Args:
        pipeline_metrics: dict with ``'masking_methods'``, ``'sampling_methods'``
            and ``'scores'`` (``(masking, sampling)`` -> ``{'detectability',
            'distortion'}``).

    Returns:
        dict with keys ``'labels'``, ``'source'``, ``'target'``, ``'value'``,
        ``'colors'`` (parallel link lists) and ``'best_combo'``.

    Raises:
        ValueError: if ``pipeline_metrics['scores']`` is empty.
        KeyError: if some masking/sampling pair has no entry in ``scores``.
    """
    masking = pipeline_metrics['masking_methods']
    sampling = pipeline_metrics['sampling_methods']
    scores = pipeline_metrics['scores']
    if not scores:
        raise ValueError("pipeline_metrics['scores'] must not be empty")

    # max() keeps the first maximal key, matching a strict '>' linear scan,
    # but unlike a scan seeded with 0 it never leaves the result unset.
    best_combo = max(scores, key=lambda combo: _combo_score(scores[combo]))
    best_mask, best_sample = best_combo

    combo_scores = {
        (mask, sample): _combo_score(scores[(mask, sample)])
        for mask in masking
        for sample in sampling
    }

    labels = ['Input'] + list(masking) + list(sampling) + ['Output']
    mask_start = 1
    sample_start = mask_start + len(masking)
    output_idx = len(labels) - 1

    source, target, value, colors = [], [], [], []

    def add_link(src, dst, weight, highlighted):
        # Append one link to the four parallel arrays Plotly expects.
        source.append(src)
        target.append(dst)
        value.append(weight)
        colors.append(_BEST_LINK_COLOR if highlighted else _DEFAULT_LINK_COLOR)

    # Input -> masking: carry each masking node's total outflow.
    for i, mask in enumerate(masking):
        outflow = sum(combo_scores[(mask, sample)] for sample in sampling)
        add_link(0, mask_start + i, outflow, mask == best_mask)

    # Masking -> sampling: one link per combination, weighted by its score.
    for i, mask in enumerate(masking):
        for j, sample in enumerate(sampling):
            add_link(mask_start + i, sample_start + j,
                     combo_scores[(mask, sample)],
                     (mask, sample) == best_combo)

    # Sampling -> Output: carry each sampling node's total inflow.
    for j, sample in enumerate(sampling):
        inflow = sum(combo_scores[(mask, sample)] for mask in masking)
        add_link(sample_start + j, output_idx, inflow, sample == best_sample)

    return {
        'labels': labels,
        'source': source,
        'target': target,
        'value': value,
        'colors': colors,
        'best_combo': best_combo,
    }


def generate_sankey_diagram(pipeline_metrics=None):
    """Build a Plotly Sankey figure of the watermarking pipeline.

    Flow runs Input -> masking method -> sampling method -> Output. The
    combination with the highest ``detectability * (1 - distortion)`` score
    is highlighted in red and named in the title.

    Args:
        pipeline_metrics: optional dict with keys ``'masking_methods'``,
            ``'sampling_methods'`` and ``'scores'`` (mapping
            ``(masking, sampling)`` tuples to dicts holding
            ``'detectability'`` and ``'distortion'``). Defaults to the
            built-in demo table.

    Returns:
        plotly.graph_objects.Figure containing a single Sankey trace.

    Raises:
        ValueError: if the scores table is empty.
        KeyError: if a masking/sampling pair is missing from the scores table.
    """
    if pipeline_metrics is None:
        pipeline_metrics = _default_pipeline_metrics()

    data = _build_sankey_data(pipeline_metrics)
    best_mask, best_sample = data['best_combo']

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=data['labels'],
            color="lightblue",
        ),
        link=dict(
            source=data['source'],
            target=data['target'],
            value=data['value'],
            color=data['colors'],
        ),
    )])
    fig.update_layout(
        title_text=f"Watermarking Pipeline Flow<br>Best Combination: {best_mask} + {best_sample}",
        font_size=12,
        height=500,
    )
    return fig