import numpy as np
import plotly.graph_objects as go
from scipy.interpolate import griddata


def _min_max_normalize(values):
    """Return *values* rescaled linearly onto [0, 1].

    A constant input (zero range) maps to all zeros instead of dividing by
    zero, which would produce NaNs that silently corrupt a later ``argmax``.
    """
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        return np.zeros(values.shape, dtype=float)
    return (values - lowest) / span


def gen_three_D_plot(detectability_val, distortion_val, euclidean_val,
                     grid_resolution=30, contour_size=0.1):
    """Build a 3-D surface of Euclidean distance over (detectability, distortion).

    The scattered samples are linearly interpolated onto a regular
    ``grid_resolution`` x ``grid_resolution`` grid and drawn as a Plotly
    surface with z-contours. The sample with the best composite score
    (high detectability, low distortion, low Euclidean distance, each
    min-max normalized to [0, 1]) is marked as the "Sweet Spot".

    Args:
        detectability_val: Detectability score per sample (x axis).
        distortion_val: Distortion score per sample (y axis).
        euclidean_val: Euclidean distance per sample (z axis).
        grid_resolution: Number of grid points along each horizontal axis
            of the interpolated surface. Defaults to 30.
        contour_size: Spacing between z-contour levels, in the units of
            ``euclidean_val``. Defaults to 0.1.

    Returns:
        plotly.graph_objects.Figure: The surface plus a sweet-spot marker.

    Raises:
        ValueError: If the inputs are empty or differ in length, if
            ``grid_resolution`` < 2 or ``contour_size`` <= 0, or if the
            interpolation yields no finite values. Errors raised by
            ``scipy.interpolate.griddata`` itself (e.g. when the points are
            too few or collinear to triangulate) propagate unchanged.
    """
    detectability = np.asarray(detectability_val, dtype=float)
    distortion = np.asarray(distortion_val, dtype=float)
    euclidean = np.asarray(euclidean_val, dtype=float)

    if detectability.size == 0:
        raise ValueError("Input values must not be empty.")
    if not (detectability.shape == distortion.shape == euclidean.shape):
        raise ValueError(
            "detectability_val, distortion_val and euclidean_val "
            "must have the same length."
        )
    if grid_resolution < 2:
        raise ValueError("grid_resolution must be at least 2.")
    if contour_size <= 0:
        raise ValueError("contour_size must be positive.")

    # Composite score: maximize detectability, minimize distortion and
    # Euclidean distance. Normalizing first puts all three on the same scale.
    composite_score = (
        _min_max_normalize(detectability)
        - _min_max_normalize(distortion)
        - _min_max_normalize(euclidean)
    )
    sweet_spot_index = int(np.argmax(composite_score))

    sweet_spot_detectability = detectability[sweet_spot_index]
    sweet_spot_distortion = distortion[sweet_spot_index]
    sweet_spot_euclidean = euclidean[sweet_spot_index]

    # Regular grid spanning the observed detectability (x) and distortion (y).
    x_grid, y_grid = np.meshgrid(
        np.linspace(detectability.min(), detectability.max(), grid_resolution),
        np.linspace(distortion.min(), distortion.max(), grid_resolution),
    )

    # Interpolate z values (Euclidean distances) onto the grid.
    z_grid = griddata((detectability, distortion), euclidean,
                      (x_grid, y_grid), method="linear")
    # griddata never returns None: grid points outside the convex hull of
    # the samples become NaN, so an all-NaN grid means nothing was plotted.
    if not np.isfinite(z_grid).any():
        raise ValueError("griddata could not generate a valid interpolation. Check your input data.")

    # 3-D surface with z-contours, using the Plasma color scale.
    fig = go.Figure(data=go.Surface(
        z=z_grid,
        x=x_grid,
        y=y_grid,
        contours={
            "z": {
                "show": True,
                "start": euclidean.min(),
                "end": euclidean.max(),
                "size": contour_size,
                "usecolormap": True,
            }
        },
        colorscale="Plasma",
    ))

    # Marker for the sweet spot.
    fig.add_trace(go.Scatter3d(
        x=[sweet_spot_detectability],
        y=[sweet_spot_distortion],
        z=[sweet_spot_euclidean],
        mode="markers+text",
        marker=dict(size=10, color="red", symbol="circle"),
        text=["Sweet Spot"],
        textposition="top center",
    ))

    fig.update_layout(
        scene=dict(
            xaxis_title="Detectability Score",
            yaxis_title="Distortion Score",
            zaxis_title="Euclidean Distance",
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )
    return fig