Spaces:
Running
Running
# import numpy as np | |
# import plotly.graph_objects as go | |
# from scipy.interpolate import griddata | |
# def gen_three_D_plot(detectability_val, distortion_val, euclidean_val): | |
# detectability = np.array(detectability_val) | |
# distortion = np.array(distortion_val) | |
# euclidean = np.array(euclidean_val) | |
# # Find the closest point to the origin | |
# distances_to_origin = np.linalg.norm(np.array([distortion, detectability, euclidean]).T, axis=1) | |
# closest_point_index = np.argmin(distances_to_origin) | |
# # Determine the closest points to each axis | |
# closest_to_x_axis = np.argmin(distortion) | |
# closest_to_y_axis = np.argmin(detectability) | |
# closest_to_z_axis = np.argmin(euclidean) | |
# # Use the detected closest point as the "sweet spot" | |
# sweet_spot_detectability = detectability[closest_point_index] | |
# sweet_spot_distortion = distortion[closest_point_index] | |
# sweet_spot_euclidean = euclidean[closest_point_index] | |
# # Create a meshgrid from the data | |
# x_grid, y_grid = np.meshgrid(np.linspace(min(detectability), max(detectability), 30), | |
# np.linspace(min(distortion), max(distortion), 30)) | |
# # Interpolate z values (Euclidean distances) to fit the grid | |
# z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='linear') | |
# if z_grid is None: | |
# raise ValueError("griddata could not generate a valid interpolation. Check your input data.") | |
# # Create the 3D contour plot with the Plasma color scale | |
# fig = go.Figure(data=go.Surface( | |
# z=z_grid, | |
# x=x_grid, | |
# y=y_grid, | |
# contours={ | |
# "z": {"show": True, "start": min(euclidean), "end": max(euclidean), "size": 0.1, "usecolormap": True} | |
# }, | |
# colorscale='Plasma' | |
# )) | |
# # Add a marker for the sweet spot | |
# fig.add_trace(go.Scatter3d( | |
# x=[sweet_spot_detectability], | |
# y=[sweet_spot_distortion], | |
# z=[sweet_spot_euclidean], | |
# mode='markers+text', | |
# marker=dict(size=10, color='red', symbol='circle'), | |
# text=["Sweet Spot"], | |
# textposition="top center" | |
# )) | |
# # Set axis labels | |
# fig.update_layout( | |
# scene=dict( | |
# xaxis_title='Detectability Score', | |
# yaxis_title='Distortion Score', | |
# zaxis_title='Euclidean Distance' | |
# ), | |
# margin=dict(l=0, r=0, b=0, t=0) | |
# ) | |
# return fig | |
import numpy as np | |
import plotly.graph_objects as go | |
from scipy.interpolate import griddata | |
def _to_1d_float_array(values, name):
    """Convert *values* to a flat float64 array, rejecting empty input."""
    arr = np.asarray(values, dtype=float).ravel()
    if arr.size == 0:
        raise ValueError(f"{name} must contain at least one value.")
    return arr


def _min_max_normalize(values):
    """Linearly rescale *values* to the range [0, 1].

    A constant array (zero range) maps to all zeros instead of producing
    NaN from a 0/0 division, so that metric simply does not influence the
    composite score.
    """
    lo = values.min()
    span = values.max() - lo
    if span == 0:
        return np.zeros_like(values)
    return (values - lo) / span


def gen_three_D_plot(detectability_val, distortion_val, euclidean_val, *,
                     grid_points=30, contour_size=0.1):
    """Build a 3D surface of Euclidean distance over (detectability, distortion).

    A "sweet spot" marker is placed on the sample with the best composite
    score. The score is the min-max-normalized detectability minus the sum
    of the min-max-normalized distortion and Euclidean distance. In other
    words, it rewards high detectability and low distortion/distance.

    Args:
        detectability_val: Sequence of detectability scores (plot x-axis).
        distortion_val: Sequence of distortion scores (plot y-axis).
        euclidean_val: Sequence of Euclidean distances (plot z-axis).
            All three sequences must have the same length.
        grid_points: Number of samples per axis in the interpolation grid
            (default 30).
        contour_size: Spacing between z contour lines on the surface
            (default 0.1). Choose it relative to the Euclidean range.

    Returns:
        A ``plotly.graph_objects.Figure`` containing the interpolated surface
        and a red "Sweet Spot" marker.

    Raises:
        ValueError: If an input is empty, the lengths differ, the grid or
            contour parameters are invalid, or the interpolation yields no
            finite values.
    """
    detectability = _to_1d_float_array(detectability_val, "detectability_val")
    distortion = _to_1d_float_array(distortion_val, "distortion_val")
    euclidean = _to_1d_float_array(euclidean_val, "euclidean_val")
    if not (detectability.size == distortion.size == euclidean.size):
        raise ValueError(
            "detectability_val, distortion_val and euclidean_val must have "
            "the same length."
        )
    if grid_points < 2:
        raise ValueError("grid_points must be at least 2.")
    if contour_size <= 0:
        raise ValueError("contour_size must be positive.")

    # Composite score: maximize detectability, minimize distortion and
    # Euclidean distance (hence those two terms are subtracted).
    composite_score = _min_max_normalize(detectability) - (
        _min_max_normalize(distortion) + _min_max_normalize(euclidean)
    )
    sweet_spot_index = int(np.argmax(composite_score))
    sweet_spot_detectability = detectability[sweet_spot_index]
    sweet_spot_distortion = distortion[sweet_spot_index]
    sweet_spot_euclidean = euclidean[sweet_spot_index]

    # Regular grid spanning the observed detectability (x) and distortion (y)
    # ranges. meshgrid puts x along columns and y along rows.
    x_grid, y_grid = np.meshgrid(
        np.linspace(detectability.min(), detectability.max(), grid_points),
        np.linspace(distortion.min(), distortion.max(), grid_points),
    )

    # Interpolate the Euclidean distances onto the grid. griddata never
    # returns None: grid points outside the convex hull of the samples are
    # filled with NaN. An all-NaN grid therefore means nothing can be drawn.
    # (Degenerate inputs, e.g. fewer than 3 non-collinear points, may instead
    # make scipy's triangulation raise.)
    z_grid = griddata((detectability, distortion), euclidean,
                      (x_grid, y_grid), method='linear')
    if np.isnan(z_grid).all():
        raise ValueError("griddata could not generate a valid interpolation. Check your input data.")

    # Surface with z contour lines coloured by the Plasma colour scale.
    fig = go.Figure(data=go.Surface(
        z=z_grid,
        x=x_grid,
        y=y_grid,
        contours={
            "z": {
                "show": True,
                "start": euclidean.min(),
                "end": euclidean.max(),
                "size": contour_size,
                "usecolormap": True,
            }
        },
        colorscale='Plasma',
    ))

    # Highlight the sweet spot chosen by the composite score.
    fig.add_trace(go.Scatter3d(
        x=[sweet_spot_detectability],
        y=[sweet_spot_distortion],
        z=[sweet_spot_euclidean],
        mode='markers+text',
        marker=dict(size=10, color='red', symbol='circle'),
        text=["Sweet Spot"],
        textposition="top center",
    ))

    fig.update_layout(
        scene=dict(
            xaxis_title='Detectability Score',
            yaxis_title='Distortion Score',
            zaxis_title='Euclidean Distance',
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )
    return fig