aiisc-watermarking-model / threeD_plot.py
BheemaShankerNeyigapula's picture
Upload folder using huggingface_hub
ea6afa4 verified
import numpy as np
import plotly.graph_objects as go
from scipy.interpolate import griddata
def gen_three_D_plot(
    detectability_val,
    distortion_val,
    euclidean_val,
    *,
    grid_resolution=30,
    contour_size=0.1,
):
    """Build a 3D surface of Euclidean distance over (detectability, distortion).

    The scattered samples are linearly interpolated onto a regular
    ``grid_resolution x grid_resolution`` grid and drawn as a surface using
    the Plasma colour scale, with z-contours. The "sweet spot" -- the sample
    with the highest composite score (min-max normalized detectability minus
    normalized distortion minus normalized Euclidean distance, equal weights)
    -- is marked with a red, labelled point.

    Args:
        detectability_val: 1-D sequence of detectability scores (x axis).
        distortion_val: 1-D sequence of distortion scores (y axis); must have
            the same length as ``detectability_val``.
        euclidean_val: 1-D sequence of Euclidean distances (z axis); must
            have the same length as ``detectability_val``.
        grid_resolution: Keyword-only. Number of grid points per axis used
            for interpolation. Must be at least 2. Defaults to 30.
        contour_size: Keyword-only. Spacing between z-contour levels, in the
            units of ``euclidean_val``. Must be positive. Defaults to 0.1.

    Returns:
        A ``plotly.graph_objects.Figure`` containing the interpolated surface
        and the sweet-spot marker.

    Raises:
        ValueError: If the inputs are not 1-D, have different lengths, or are
            empty; if ``grid_resolution < 2`` or ``contour_size <= 0``; or if
            the interpolation produces no finite values.

    Note:
        scipy's Delaunay-based ``griddata`` may additionally raise its Qhull
        error for degenerate point sets (e.g. fewer than three points, or all
        points collinear).
    """
    detectability = np.asarray(detectability_val)
    distortion = np.asarray(distortion_val)
    euclidean = np.asarray(euclidean_val)

    # Validate up front. A length-1 array would otherwise broadcast silently
    # against the others in the composite score and fail later with an
    # unrelated IndexError or interpolation error.
    named_inputs = (
        ("detectability", detectability),
        ("distortion", distortion),
        ("euclidean", euclidean),
    )
    for label, values in named_inputs:
        if values.ndim != 1:
            raise ValueError(
                f"{label} values must be one-dimensional, got shape {values.shape}"
            )
    if not (detectability.size == distortion.size == euclidean.size):
        raise ValueError(
            "detectability, distortion and euclidean must have the same length, "
            f"got {detectability.size}, {distortion.size} and {euclidean.size}"
        )
    if detectability.size == 0:
        raise ValueError("At least one data point is required.")
    if grid_resolution < 2:
        raise ValueError(f"grid_resolution must be >= 2, got {grid_resolution}")
    if contour_size <= 0:
        raise ValueError(f"contour_size must be positive, got {contour_size}")

    def normalize(data):
        """Min-max scale ``data`` to [0, 1]; a constant array maps to zeros."""
        min_val, max_val = np.min(data), np.max(data)
        if max_val > min_val:
            return (data - min_val) / (max_val - min_val)
        return np.zeros_like(data)

    # Composite score: reward detectability, penalize distortion and Euclidean
    # distance. Normalizing first puts the three metrics on a common scale.
    composite_score = normalize(detectability) - (
        normalize(distortion) + normalize(euclidean)
    )
    sweet_spot_index = np.argmax(composite_score)
    sweet_spot = (
        detectability[sweet_spot_index],
        distortion[sweet_spot_index],
        euclidean[sweet_spot_index],
    )

    # Regular grid spanning the bounding box of the samples.
    x_grid, y_grid = np.meshgrid(
        np.linspace(np.min(detectability), np.max(detectability), grid_resolution),
        np.linspace(np.min(distortion), np.max(distortion), grid_resolution),
    )

    # Interpolate the Euclidean distances onto the grid. griddata never
    # returns None. With method="linear", grid points outside the convex hull
    # of the samples are filled with NaN, so an all-NaN grid means the
    # interpolation failed.
    z_grid = griddata(
        (detectability, distortion), euclidean, (x_grid, y_grid), method="linear"
    )
    if not np.isfinite(z_grid).any():
        raise ValueError(
            "griddata could not generate a valid interpolation. Check your input data."
        )

    # 3D surface with z-contours, coloured with the Plasma scale.
    fig = go.Figure(
        data=go.Surface(
            z=z_grid,
            x=x_grid,
            y=y_grid,
            contours={
                "z": {
                    "show": True,
                    "start": float(np.min(euclidean)),
                    "end": float(np.max(euclidean)),
                    "size": contour_size,
                    "usecolormap": True,
                }
            },
            colorscale="Plasma",
        )
    )

    # Highlight the best-scoring sample.
    fig.add_trace(
        go.Scatter3d(
            x=[sweet_spot[0]],
            y=[sweet_spot[1]],
            z=[sweet_spot[2]],
            mode="markers+text",
            marker=dict(size=10, color="red", symbol="circle"),
            text=["Sweet Spot"],
            textposition="top center",
        )
    )

    fig.update_layout(
        scene=dict(
            xaxis_title="Detectability Score",
            yaxis_title="Distortion Score",
            zaxis_title="Euclidean Distance",
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )
    return fig