# Source: uploaded by jgyasu to Hugging Face ("Upload folder using huggingface_hub"),
# commit 2471de4 (verified).
import plotly.graph_objects as go
# Link colours: faint blue for ordinary flows, solid red to highlight the
# flows that belong to the best-scoring (masking, sampling) combination.
_DIM_LINK_COLOR = 'rgba(0,0,255,0.2)'
_BEST_LINK_COLOR = 'rgba(255,0,0,0.8)'

# Built-in demo metrics, used when the caller supplies none. Every
# (masking method, sampling method) pair must have an entry in 'scores'.
# This object is only read, never mutated, so sharing it is safe.
_DEFAULT_PIPELINE_METRICS = {
    'masking_methods': ['random masking', 'pseudorandom masking', 'high-entropy masking'],
    'sampling_methods': ['inverse_transform sampling', 'exponential_minimum sampling', 'temperature sampling', 'greedy sampling'],
    'scores': {
        ('random masking', 'inverse_transform sampling'): {'detectability': 0.8, 'distortion': 0.2},
        ('random masking', 'exponential_minimum sampling'): {'detectability': 0.7, 'distortion': 0.3},
        ('random masking', 'temperature sampling'): {'detectability': 0.6, 'distortion': 0.4},
        ('random masking', 'greedy sampling'): {'detectability': 0.5, 'distortion': 0.5},
        ('pseudorandom masking', 'inverse_transform sampling'): {'detectability': 0.75, 'distortion': 0.25},
        ('pseudorandom masking', 'exponential_minimum sampling'): {'detectability': 0.65, 'distortion': 0.35},
        ('pseudorandom masking', 'temperature sampling'): {'detectability': 0.55, 'distortion': 0.45},
        ('pseudorandom masking', 'greedy sampling'): {'detectability': 0.45, 'distortion': 0.55},
        ('high-entropy masking', 'inverse_transform sampling'): {'detectability': 0.85, 'distortion': 0.15},
        ('high-entropy masking', 'exponential_minimum sampling'): {'detectability': 0.75, 'distortion': 0.25},
        ('high-entropy masking', 'temperature sampling'): {'detectability': 0.65, 'distortion': 0.35},
        ('high-entropy masking', 'greedy sampling'): {'detectability': 0.55, 'distortion': 0.45},
    },
}


def _combo_score(metrics):
    """Return ``detectability * (1 - distortion)`` for one combination's metrics dict."""
    return metrics['detectability'] * (1 - metrics['distortion'])


def _link_color(is_best):
    """Return the highlight colour for a best-path link, the dim colour otherwise."""
    return _BEST_LINK_COLOR if is_best else _DIM_LINK_COLOR


def generate_sankey_diagram(pipeline_metrics=None):
    """Build a Plotly Sankey diagram of the watermarking pipeline.

    Nodes flow Input -> masking methods -> sampling methods -> Output.
    Masking->sampling link widths are each combination's score,
    ``detectability * (1 - distortion)``. Input->masking and
    sampling->Output links have width 1. Links on the best-scoring path
    are drawn in red; all others in faint blue.

    Args:
        pipeline_metrics: Optional dict with keys ``'masking_methods'`` and
            ``'sampling_methods'`` (sequences of names) and ``'scores'``
            (mapping ``(masking, sampling)`` tuples to dicts with
            ``'detectability'`` and ``'distortion'``). Defaults to the
            built-in demo metrics.

    Returns:
        plotly.graph_objects.Figure: the Sankey figure, titled with the
        best combination.

    Raises:
        ValueError: if ``pipeline_metrics['scores']`` is empty.
        KeyError: if a masking/sampling pair has no entry in ``'scores'``.
    """
    if pipeline_metrics is None:
        pipeline_metrics = _DEFAULT_PIPELINE_METRICS
    masking_methods = pipeline_metrics['masking_methods']
    sampling_methods = pipeline_metrics['sampling_methods']
    scores = pipeline_metrics['scores']

    if not scores:
        raise ValueError("pipeline_metrics['scores'] must not be empty")
    # max() returns the first maximal key, the same tie-breaking as a
    # strict '>' scan in insertion order. Unlike a scan seeded with 0, it
    # always selects a combination, even when every score is <= 0.
    best_combo = max(scores, key=lambda combo: _combo_score(scores[combo]))
    best_mask, best_sample = best_combo

    # Node indices: 0 = Input, 1..M = masking, then sampling, last = Output.
    label_list = ['Input', *masking_methods, *sampling_methods, 'Output']
    sampling_start = len(masking_methods) + 1
    output_idx = len(label_list) - 1

    source, target, value, colors = [], [], [], []

    def add_link(src, dst, width, is_best):
        """Append one link to the parallel source/target/value/colour lists."""
        source.append(src)
        target.append(dst)
        value.append(width)
        colors.append(_link_color(is_best))

    # Input -> each masking method.
    for i, mask in enumerate(masking_methods):
        add_link(0, i + 1, 1, mask == best_mask)

    # Masking -> sampling, sized by the combination score.
    # NOTE(review): these widths do not sum to the unit-width links entering
    # the masking nodes or leaving the sampling nodes, so node in/out flows
    # are unbalanced. Confirm this is intended.
    for i, mask in enumerate(masking_methods):
        for j, sample in enumerate(sampling_methods):
            add_link(
                i + 1,
                sampling_start + j,
                _combo_score(scores[(mask, sample)]),
                (mask, sample) == best_combo,
            )

    # Each sampling method -> Output.
    for j, sample in enumerate(sampling_methods):
        add_link(sampling_start + j, output_idx, 1, sample == best_sample)

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=label_list,
            color="lightblue",
        ),
        link=dict(
            source=source,
            target=target,
            value=value,
            color=colors,
        ),
    )])

    fig.update_layout(
        title_text=f"Watermarking Pipeline Flow<br>Best Combination: {best_combo[0]} + {best_combo[1]}",
        font_size=12,
        height=500,
    )
    return fig