File size: 4,706 Bytes
6f08eef
 
 
 
11827d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f08eef
 
11827d2
 
 
 
 
 
 
 
 
 
 
 
6f08eef
11827d2
6f08eef
 
11827d2
6f08eef
11827d2
6f08eef
 
11827d2
6f08eef
 
 
11827d2
6f08eef
 
 
11827d2
6f08eef
 
11827d2
 
 
 
6f08eef
 
 
11827d2
 
 
 
6f08eef
 
 
 
 
 
 
11827d2
6f08eef
 
 
 
11827d2
 
6f08eef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
import numpy as np 


def preregister_mean_std(verts_to_transform, target_verts, single_scale=True):
    '''Align one point set to another by matching mean and spread.

    The points in ``verts_to_transform`` are standardized with their own mean
    and standard deviation (taken along axis 0), then rescaled and shifted by
    the mean and standard deviation of ``target_verts``.

    verts_to_transform: array-like of points (one point per row) to move
    target_verts: array-like of reference points (one point per row)
    single_scale: if True, each point set uses one isotropic scale, the
        Euclidean norm of its per-axis standard deviations; if False, each
        axis is scaled independently

    Returns a NumPy array with the transformed points, same shape as
    ``verts_to_transform``.
    '''
    # asarray lets callers pass plain lists; existing arrays are not copied.
    verts_to_transform = np.asarray(verts_to_transform)
    target_verts = np.asarray(target_verts)

    mu_target = target_verts.mean(axis=0)
    mu_in = verts_to_transform.mean(axis=0)
    std_target = np.std(target_verts, axis=0)
    std_in = np.std(verts_to_transform, axis=0)

    # Axes with zero spread (e.g. a single point) would divide by zero, and NaN
    # spreads (e.g. empty input) would poison the result; use unit scale on
    # those axes. This per-axis fallback happens before any isotropic norm.
    std_in = np.where((std_in == 0) | np.isnan(std_in), 1, std_in)
    std_target = np.where((std_target == 0) | np.isnan(std_target), 1, std_target)

    if single_scale:
        std_target = np.linalg.norm(std_target)
        std_in = np.linalg.norm(std_in)

    transformed_verts = (verts_to_transform - mu_in) / std_in
    transformed_verts = transformed_verts * std_target + mu_target

    return transformed_verts


def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=-1, ce=1.0, normalized=True, preregister=True, single_scale=True):
    '''The function computes the Wireframe Edge Distance (WED) between two graphs.

    pd_vertices: predicted vertex coordinates, one vertex per row
    pd_edges: predicted edges as pairs of indices into pd_vertices
    gt_vertices: ground truth vertex coordinates, one vertex per row
    gt_edges: ground truth edges as pairs of indices into gt_vertices
    cv: vertex cost, i.e. the cost in centimeters of inserting or deleting one
        vertex. A negative value (the default, -1) means 1/4 of the diameter
        of the ground truth wireframe.
    ce: edge cost, a multiplier of the edge length for edge deletion and
        insertion (default 1.0)
    normalized: if True, the WED is divided by the total length of the ground
        truth edges (a zero total length yields inf/nan, as with any float
        division by zero in NumPy)
    preregister: if True, the predicted vertices have their mean and scale
        matched to the ground truth vertices before matching
    single_scale: forwarded to preregister_mean_std; if True, preregistration
        uses one isotropic scale instead of a per-axis scale

    Returns the (optionally normalized) WED.
    '''

    # Vertex coordinates are in centimeters. When cv and ce are set to 100.0 and 1.0 respectively,
    # missing a vertex is equivalent to predicting it 1 meter away from the ground truth vertex.
    # This is equivalent to setting cv=1 and ce=1 when the vertex coordinates are in meters.

    pd_vertices = np.array(pd_vertices)
    gt_vertices = np.array(gt_vertices)
    # Integer (E, 2) layout keeps an empty edge list indexable in the
    # normalization step below.
    gt_edges = np.asarray(gt_edges, dtype=int).reshape(-1, 2)

    if cv < 0:
        # Cost of adding or deleting a vertex is set to 1/4 of the diameter
        # (largest pairwise vertex distance) of the ground truth wireframe.
        cv = cdist(gt_vertices, gt_vertices).max() / 4.0

    # Step 0: Prenormalize / preregister
    if preregister:
        pd_vertices = preregister_mean_std(pd_vertices, gt_vertices, single_scale=single_scale)

    # Step 1: Bipartite matching (minimum total Euclidean distance)
    distances = cdist(pd_vertices, gt_vertices, metric='euclidean')
    row_ind, col_ind = linear_sum_assignment(distances)
    # The assignment is one-to-one, so both directions are well defined.
    pd_to_gt = {int(p): int(g) for p, g in zip(row_ind, col_ind)}
    gt_to_pd = {g: p for p, g in pd_to_gt.items()}

    # Step 2: Vertex translation
    translation_costs = np.sum(distances[row_ind, col_ind])

    # Additional: Vertex deletion (predicted vertices left unmatched)
    deletion_costs = cv * (len(pd_vertices) - len(pd_to_gt))

    # Step 3: Vertex insertion (ground truth vertices left unmatched)
    insertion_costs = cv * (len(gt_vertices) - len(gt_to_pd))

    # Step 4: Edge deletion and insertion.
    # Edges are undirected, so both sides are canonicalized as sorted index
    # pairs in ground-truth index space. (tuple(set(edge)) is not a canonical
    # form: set iteration order can depend on insertion order.)
    # NOTE(review): predicted edges touching an unmatched (deleted) vertex are
    # skipped and incur no edge-deletion cost; confirm this is intended.
    pd_edges_set = {
        tuple(sorted((pd_to_gt[int(a)], pd_to_gt[int(b)])))
        for a, b in pd_edges
        if int(a) in pd_to_gt and int(b) in pd_to_gt
    }
    gt_edges_set = {tuple(sorted((int(a), int(b)))) for a, b in gt_edges}

    # Delete edges not in ground truth; their length is measured between the
    # predicted vertices matched to each ground-truth endpoint.
    edges_to_delete = pd_edges_set - gt_edges_set
    deletion_edge_costs = ce * sum(
        np.linalg.norm(pd_vertices[gt_to_pd[a]] - pd_vertices[gt_to_pd[b]])
        for a, b in edges_to_delete
    )

    # Insert missing edges from ground truth
    edges_to_insert = gt_edges_set - pd_edges_set
    insertion_edge_costs = ce * sum(
        np.linalg.norm(gt_vertices[a] - gt_vertices[b])
        for a, b in edges_to_insert
    )

    # Step 5: Calculation of WED
    WED = translation_costs + deletion_costs + insertion_costs + deletion_edge_costs + insertion_edge_costs

    if normalized:
        total_length_of_gt_edges = np.linalg.norm(
            gt_vertices[gt_edges[:, 0]] - gt_vertices[gt_edges[:, 1]], axis=1
        ).sum()
        WED = WED / total_length_of_gt_edges

    return WED