Spaces:
Runtime error
Runtime error
import numpy as np | |
import plotly.graph_objects as go | |
from scipy.interpolate import griddata | |
def gen_three_D_plot(detectability_val, distortion_val, euclidean_val,
                     grid_resolution=30, contour_size=0.1):
    """Build a 3-D Plotly surface of Euclidean distance over (detectability, distortion).

    The scattered samples ``(detectability, distortion) -> euclidean`` are
    linearly interpolated onto a regular ``grid_resolution x grid_resolution``
    grid and drawn as a surface with z-contours. A red "Sweet Spot" marker is
    placed at the sample maximizing
    ``norm(detectability) - (norm(distortion) + norm(euclidean))``, where each
    metric is min-max normalized to [0, 1] (a constant metric normalizes to 0).

    Args:
        detectability_val: Detectability score per sample (x axis).
        distortion_val: Distortion score per sample (y axis).
        euclidean_val: Euclidean distance per sample (z axis).
        grid_resolution: Number of grid points along each of x and y.
        contour_size: Step between z contour lines, in the units of
            ``euclidean_val``. Choose it relative to the data's range.

    Returns:
        A ``plotly.graph_objects.Figure`` containing the surface and marker.

    Raises:
        ValueError: If the inputs are empty, have different lengths, or the
            interpolation produces no finite values. Note that scipy's
            Delaunay step may also raise for fewer than three or collinear
            points.
    """
    # Flatten to 1-D floats so argmax's flat index addresses every array alike.
    detectability = np.asarray(detectability_val, dtype=float).ravel()
    distortion = np.asarray(distortion_val, dtype=float).ravel()
    euclidean = np.asarray(euclidean_val, dtype=float).ravel()

    # Reject mismatched lengths up front: NumPy would otherwise broadcast a
    # length-1 array silently and index the wrong sample later.
    if not (detectability.size == distortion.size == euclidean.size):
        raise ValueError(
            "detectability_val, distortion_val and euclidean_val must have the "
            f"same length (got {detectability.size}, {distortion.size}, "
            f"{euclidean.size})."
        )
    if detectability.size == 0:
        raise ValueError("Input lists must not be empty.")

    def normalize(data):
        """Min-max scale ``data`` to [0, 1]; a constant array maps to zeros."""
        min_val, max_val = np.min(data), np.max(data)
        if max_val > min_val:
            return (data - min_val) / (max_val - min_val)
        return np.zeros_like(data)

    # Composite score: reward detectability, penalize distortion and distance.
    composite_score = normalize(detectability) - (
        normalize(distortion) + normalize(euclidean)
    )
    sweet_spot_index = np.argmax(composite_score)
    sweet_spot = (
        detectability[sweet_spot_index],
        distortion[sweet_spot_index],
        euclidean[sweet_spot_index],
    )

    # Regular grid spanning the bounding box of the samples.
    x_grid, y_grid = np.meshgrid(
        np.linspace(np.min(detectability), np.max(detectability), grid_resolution),
        np.linspace(np.min(distortion), np.max(distortion), grid_resolution),
    )

    # griddata returns NaN (never None) outside the samples' convex hull, so an
    # all-NaN grid is the real "no usable interpolation" condition.
    z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='linear')
    if np.all(np.isnan(z_grid)):
        raise ValueError("griddata could not generate a valid interpolation. Check your input data.")

    # 3-D surface with z-contours, using the Plasma color scale.
    fig = go.Figure(data=go.Surface(
        z=z_grid,
        x=x_grid,
        y=y_grid,
        contours={"z": {
            "show": True,
            "start": float(np.min(euclidean)),
            "end": float(np.max(euclidean)),
            "size": contour_size,
            "usecolormap": True,
        }},
        colorscale='Plasma',
    ))

    # Marker for the sweet spot.
    fig.add_trace(go.Scatter3d(
        x=[sweet_spot[0]],
        y=[sweet_spot[1]],
        z=[sweet_spot[2]],
        mode='markers+text',
        marker=dict(size=10, color='red', symbol='circle'),
        text=["Sweet Spot"],
        textposition="top center",
    ))

    fig.update_layout(
        scene=dict(
            xaxis_title='Detectability Score',
            yaxis_title='Distortion Score',
            zaxis_title='Euclidean Distance',
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )
    return fig