import numpy as np
import plotly.graph_objects as go
from scipy.interpolate import griddata


def _normalize(data):
    """Min-max scale *data* to [0, 1]; a constant array maps to all zeros."""
    min_val, max_val = np.min(data), np.max(data)
    if max_val > min_val:
        return (data - min_val) / (max_val - min_val)
    # Avoid division by zero when every value is identical.
    return np.zeros_like(data)


def gen_three_D_plot(
    detectability_val,
    distortion_val,
    euclidean_val,
    *,
    grid_resolution=30,
    contour_size=0.1,
):
    """Build a 3D surface of Euclidean distance over detectability and distortion.

    The three inputs are treated as parallel samples: sample ``i`` is the
    point ``(detectability_val[i], distortion_val[i])`` with height
    ``euclidean_val[i]``. The scattered samples are linearly interpolated
    onto a regular grid and drawn as a Plotly surface with z-contours.

    Each metric is min-max normalized to [0, 1] and combined into the score
    ``detectability - (distortion + euclidean)``. The sample with the highest
    score is marked as the "Sweet Spot".

    Args:
        detectability_val: Sequence of detectability scores (x axis).
        distortion_val: Sequence of distortion scores (y axis).
        euclidean_val: Sequence of Euclidean distances (z axis).
        grid_resolution: Number of grid points along each of x and y used
            for interpolation (default 30).
        contour_size: Spacing between z-contour lines (default 0.1).

    Returns:
        plotly.graph_objects.Figure containing the surface and the
        sweet-spot marker.

    Raises:
        ValueError: If the three inputs do not have the same shape, or if
            interpolation yields no finite values.
    """
    detectability = np.array(detectability_val)
    distortion = np.array(distortion_val)
    euclidean = np.array(euclidean_val)

    # Fail early with a clear message instead of broadcasting silently
    # (e.g. when one input has length 1) or failing deep inside scipy.
    if not (detectability.shape == distortion.shape == euclidean.shape):
        raise ValueError(
            "detectability_val, distortion_val and euclidean_val must have "
            f"the same shape; got {detectability.shape}, {distortion.shape} "
            f"and {euclidean.shape}."
        )

    # Composite score: reward detectability, penalize distortion and distance.
    composite_score = _normalize(detectability) - (
        _normalize(distortion) + _normalize(euclidean)
    )
    sweet_spot_index = np.argmax(composite_score)
    sweet_spot = (
        detectability[sweet_spot_index],
        distortion[sweet_spot_index],
        euclidean[sweet_spot_index],
    )

    # Regular grid spanning the observed detectability/distortion ranges.
    x_grid, y_grid = np.meshgrid(
        np.linspace(np.min(detectability), np.max(detectability), grid_resolution),
        np.linspace(np.min(distortion), np.max(distortion), grid_resolution),
    )

    # griddata never returns None: grid points outside the convex hull of the
    # samples are filled with NaN. An all-NaN grid means nothing usable came out.
    z_grid = griddata(
        (detectability, distortion), euclidean, (x_grid, y_grid), method="linear"
    )
    if np.all(np.isnan(z_grid)):
        raise ValueError(
            "griddata could not generate a valid interpolation. Check your input data."
        )

    fig = go.Figure(
        data=go.Surface(
            z=z_grid,
            x=x_grid,
            y=y_grid,
            contours={
                "z": {
                    "show": True,
                    "start": np.min(euclidean),
                    "end": np.max(euclidean),
                    "size": contour_size,
                    "usecolormap": True,
                }
            },
            colorscale="Plasma",
        )
    )

    # Highlight the best-scoring sample.
    fig.add_trace(
        go.Scatter3d(
            x=[sweet_spot[0]],
            y=[sweet_spot[1]],
            z=[sweet_spot[2]],
            mode="markers+text",
            marker=dict(size=10, color="red", symbol="circle"),
            text=["Sweet Spot"],
            textposition="top center",
        )
    )

    fig.update_layout(
        scene=dict(
            xaxis_title="Detectability Score",
            yaxis_title="Distortion Score",
            zaxis_title="Euclidean Distance",
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )
    return fig