Spaces:
Running
Running
File size: 3,942 Bytes
2471de4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import plotly.graph_objects as go
def _combo_score(metrics):
    """Return the trade-off score ``detectability * (1 - distortion)`` for one combination."""
    return metrics['detectability'] * (1 - metrics['distortion'])


def generate_sankey_diagram():
    """Build a Sankey diagram of the watermarking pipeline and highlight the best path.

    Every (masking method, sampling method) pair has hard-coded
    ``detectability`` and ``distortion`` metrics. Each pair is scored as
    ``detectability * (1 - distortion)``. The highest-scoring pair (first one
    wins on ties) is drawn in red and named in the title; every other link is
    translucent blue.

    Node layout (indices into the label list)::

        0               -> 'Input'
        1 .. M          -> masking methods
        M+1 .. M+S      -> sampling methods
        M+S+1           -> 'Output'

    Link widths conserve flow at every node:

    * each masking -> sampling link carries that pair's score;
    * each Input -> masking link carries the total score leaving that
      masking node;
    * each sampling -> Output link carries the total score entering that
      sampling node.

    Returns:
        plotly.graph_objects.Figure: the Sankey figure (500 px high).

    Raises:
        ValueError: if the score table is empty, so no best combination exists.
        KeyError: if some (masking, sampling) pair has no entry in the score table.
    """
    pipeline_metrics = {
        'masking_methods': ['random masking', 'pseudorandom masking', 'high-entropy masking'],
        'sampling_methods': ['inverse_transform sampling', 'exponential_minimum sampling', 'temperature sampling', 'greedy sampling'],
        'scores': {
            ('random masking', 'inverse_transform sampling'): {'detectability': 0.8, 'distortion': 0.2},
            ('random masking', 'exponential_minimum sampling'): {'detectability': 0.7, 'distortion': 0.3},
            ('random masking', 'temperature sampling'): {'detectability': 0.6, 'distortion': 0.4},
            ('random masking', 'greedy sampling'): {'detectability': 0.5, 'distortion': 0.5},
            ('pseudorandom masking', 'inverse_transform sampling'): {'detectability': 0.75, 'distortion': 0.25},
            ('pseudorandom masking', 'exponential_minimum sampling'): {'detectability': 0.65, 'distortion': 0.35},
            ('pseudorandom masking', 'temperature sampling'): {'detectability': 0.55, 'distortion': 0.45},
            ('pseudorandom masking', 'greedy sampling'): {'detectability': 0.45, 'distortion': 0.55},
            ('high-entropy masking', 'inverse_transform sampling'): {'detectability': 0.85, 'distortion': 0.15},
            ('high-entropy masking', 'exponential_minimum sampling'): {'detectability': 0.75, 'distortion': 0.25},
            ('high-entropy masking', 'temperature sampling'): {'detectability': 0.65, 'distortion': 0.35},
            ('high-entropy masking', 'greedy sampling'): {'detectability': 0.55, 'distortion': 0.45}
        }
    }

    masking_methods = pipeline_metrics['masking_methods']
    sampling_methods = pipeline_metrics['sampling_methods']

    # Score every combination exactly once; reused for best-path selection and link widths.
    combo_scores = {
        combo: _combo_score(metrics)
        for combo, metrics in pipeline_metrics['scores'].items()
    }

    # max() returns the first maximal key in insertion order (same tie rule as a
    # strict '>' scan). Unlike seeding the search with 0/None, it always yields a
    # combination even when no score is positive.
    best_combo = max(combo_scores, key=combo_scores.get)
    best_mask, best_sample = best_combo

    highlight = 'rgba(255,0,0,0.8)'
    dimmed = 'rgba(0,0,255,0.2)'

    label_list = ['Input'] + masking_methods + sampling_methods + ['Output']
    sampling_start = len(masking_methods) + 1
    output_idx = len(label_list) - 1

    source = []
    target = []
    value = []
    colors = []

    def add_link(src, dst, flow, on_best_path):
        # Append one link to the parallel lists consumed by go.Sankey.
        source.append(src)
        target.append(dst)
        value.append(flow)
        colors.append(highlight if on_best_path else dimmed)

    # Input -> masking methods: width equals the total score leaving the masking node,
    # so inflow == outflow at that node.
    for i, mask in enumerate(masking_methods, start=1):
        outflow = sum(combo_scores[(mask, sample)] for sample in sampling_methods)
        add_link(0, i, outflow, mask == best_mask)

    # Masking -> sampling methods: width equals the pair's score.
    for i, mask in enumerate(masking_methods, start=1):
        for j, sample in enumerate(sampling_methods):
            add_link(i, sampling_start + j, combo_scores[(mask, sample)],
                     (mask, sample) == best_combo)

    # Sampling methods -> Output: width equals the total score entering the sampling node.
    for j, sample in enumerate(sampling_methods):
        inflow = sum(combo_scores[(mask, sample)] for mask in masking_methods)
        add_link(sampling_start + j, output_idx, inflow, sample == best_sample)

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=label_list,
            color="lightblue"
        ),
        link=dict(
            source=source,
            target=target,
            value=value,
            color=colors
        )
    )])
    fig.update_layout(
        title_text=f"Watermarking Pipeline Flow<br>Best Combination: {best_combo[0]} + {best_combo[1]}",
        font_size=12,
        height=500
    )
    return fig