import plotly.graph_objects as go

# Link colours: faint blue for ordinary flows, opaque red for the flows that
# make up the best-scoring (masking, sampling) combination.
_LINK_COLOR = 'rgba(0,0,255,0.2)'
_BEST_LINK_COLOR = 'rgba(255,0,0,0.8)'

# Built-in example metrics, used when the caller supplies none.
# 'scores' maps (masking_method, sampling_method) pairs to their metrics.
_DEFAULT_PIPELINE_METRICS = {
    'masking_methods': [
        'random masking',
        'pseudorandom masking',
        'high-entropy masking',
    ],
    'sampling_methods': [
        'inverse_transform sampling',
        'exponential_minimum sampling',
        'temperature sampling',
        'greedy sampling',
    ],
    'scores': {
        ('random masking', 'inverse_transform sampling'): {'detectability': 0.8, 'distortion': 0.2},
        ('random masking', 'exponential_minimum sampling'): {'detectability': 0.7, 'distortion': 0.3},
        ('random masking', 'temperature sampling'): {'detectability': 0.6, 'distortion': 0.4},
        ('random masking', 'greedy sampling'): {'detectability': 0.5, 'distortion': 0.5},
        ('pseudorandom masking', 'inverse_transform sampling'): {'detectability': 0.75, 'distortion': 0.25},
        ('pseudorandom masking', 'exponential_minimum sampling'): {'detectability': 0.65, 'distortion': 0.35},
        ('pseudorandom masking', 'temperature sampling'): {'detectability': 0.55, 'distortion': 0.45},
        ('pseudorandom masking', 'greedy sampling'): {'detectability': 0.45, 'distortion': 0.55},
        ('high-entropy masking', 'inverse_transform sampling'): {'detectability': 0.85, 'distortion': 0.15},
        ('high-entropy masking', 'exponential_minimum sampling'): {'detectability': 0.75, 'distortion': 0.25},
        ('high-entropy masking', 'temperature sampling'): {'detectability': 0.65, 'distortion': 0.35},
        ('high-entropy masking', 'greedy sampling'): {'detectability': 0.55, 'distortion': 0.45},
    },
}


def _combined_score(metrics):
    """Return detectability scaled by the undistorted fraction (1 - distortion)."""
    return metrics['detectability'] * (1 - metrics['distortion'])


def _link_color(is_best):
    """Return the highlight colour for best-path links, the faint colour otherwise."""
    return _BEST_LINK_COLOR if is_best else _LINK_COLOR


def generate_sankey_diagram(pipeline_metrics=None):
    """Build a Plotly Sankey figure of the watermarking pipeline.

    Nodes are laid out as: index 0 = 'Input', then one node per masking
    method, then one node per sampling method, then 'Output' last.
    Masking -> sampling links are weighted by
    ``detectability * (1 - distortion)``. The links on the highest-scoring
    (masking, sampling) pair are drawn in red.

    Args:
        pipeline_metrics: Optional dict with keys ``'masking_methods'``
            (list of str), ``'sampling_methods'`` (list of str) and
            ``'scores'`` (dict mapping ``(mask, sample)`` tuples to dicts
            with ``'detectability'`` and ``'distortion'`` floats).
            Defaults to the built-in example metrics.

    Returns:
        plotly.graph_objects.Figure: the Sankey diagram.

    Raises:
        ValueError: if ``pipeline_metrics['scores']`` is empty.
    """
    if pipeline_metrics is None:
        pipeline_metrics = _DEFAULT_PIPELINE_METRICS

    masking_methods = pipeline_metrics['masking_methods']
    sampling_methods = pipeline_metrics['sampling_methods']
    scores = pipeline_metrics['scores']
    if not scores:
        raise ValueError("pipeline_metrics['scores'] must not be empty")

    # max() returns the first maximal entry (same tie-breaking as a strict '>'
    # scan). Unlike a scan seeded with 0, it never leaves the result as None.
    best_combo = max(scores, key=lambda combo: _combined_score(scores[combo]))
    best_mask, best_sample = best_combo

    label_list = ['Input', *masking_methods, *sampling_methods, 'Output']
    sampling_start = len(masking_methods) + 1
    output_idx = len(label_list) - 1

    source, target, value, colors = [], [], [], []

    def add_link(src, dst, weight, is_best):
        """Append one link to the parallel Sankey link lists."""
        source.append(src)
        target.append(dst)
        value.append(weight)
        colors.append(_link_color(is_best))

    # Input -> masking methods. The unit weights here and on the Output side
    # are display-only; they are not the sums of the score-weighted links, so
    # a node's in- and out-flow need not balance.
    for i, mask in enumerate(masking_methods):
        add_link(0, i + 1, 1, mask == best_mask)

    # Masking -> sampling methods, weighted by the combined score. Pairs with
    # no entry in `scores` are skipped rather than raising KeyError.
    for i, mask in enumerate(masking_methods):
        for j, sample in enumerate(sampling_methods):
            metrics = scores.get((mask, sample))
            if metrics is None:
                continue
            add_link(i + 1, sampling_start + j, _combined_score(metrics),
                     (mask, sample) == best_combo)

    # Sampling methods -> Output.
    for j, sample in enumerate(sampling_methods):
        add_link(sampling_start + j, output_idx, 1, sample == best_sample)

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=label_list,
            color="lightblue",
        ),
        link=dict(
            source=source,
            target=target,
            value=value,
            color=colors,
        ),
    )])

    # Plotly renders '<br>' as a line break in titles; a raw newline inside
    # the f-string literal was a syntax error.
    fig.update_layout(
        title_text=(
            f"Watermarking Pipeline Flow<br>"
            f"Best Combination: {best_mask} + {best_sample}"
        ),
        font_size=12,
        height=500,
    )
    return fig