dmytromishkin commited on
Commit
6f08eef
1 Parent(s): 724cc25

Create wed.py

Browse files
Files changed (1) hide show
  1. hoho/wed.py +54 -0
hoho/wed.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from scipy.spatial.distance import cdist
2
+ from scipy.optimize import linear_sum_assignment
3
+ import numpy as np
4
+
5
+ def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv, ce, normalized=True, squared=False):
6
+ pd_vertices = np.array(pd_vertices)
7
+ gt_vertices = np.array(gt_vertices)
8
+ pd_edges = np.array(pd_edges)
9
+ gt_edges = np.array(gt_edges)
10
+
11
+ # Step 1: Bipartite Matching
12
+ if squared:
13
+ distances = cdist(pd_vertices, gt_vertices, metric='sqeuclidean')
14
+ else:
15
+ distances = cdist(pd_vertices, gt_vertices, metric='euclidean')
16
+
17
+ row_ind, col_ind = linear_sum_assignment(distances)
18
+
19
+ # Step 2: Vertex Translation
20
+
21
+ if squared:
22
+ translation_costs = cv * np.sqrt(np.sum(distances[row_ind, col_ind]))
23
+ else:
24
+ translation_costs = cv * np.sum(distances[row_ind, col_ind])
25
+
26
+ # Additional: Vertex Deletion
27
+ unmatched_pd_indices = set(range(len(pd_vertices))) - set(row_ind)
28
+ deletion_costs = cv * len(unmatched_pd_indices) # Assuming a fixed cost for vertex deletion
29
+
30
+ # Step 3: Vertex Insertion
31
+ unmatched_gt_indices = set(range(len(gt_vertices))) - set(col_ind)
32
+ insertion_costs = cv * len(unmatched_gt_indices) # Assuming a fixed cost for vertex insertion
33
+
34
+ # Step 4: Edge Deletion and Insertion
35
+ updated_pd_edges = [(row_ind[np.where(col_ind == edge[0])[0][0]], row_ind[np.where(col_ind == edge[1])[0][0]]) for edge in pd_edges if edge[0] in col_ind and edge[1] in col_ind]
36
+ pd_edges_set = set(map(tuple, updated_pd_edges))
37
+ gt_edges_set = set(map(tuple, gt_edges))
38
+
39
+ # Delete edges not in ground truth
40
+ edges_to_delete = pd_edges_set - gt_edges_set
41
+ deletion_edge_costs = ce * sum(np.linalg.norm(pd_vertices[edge[0]] - pd_vertices[edge[1]]) for edge in edges_to_delete)
42
+
43
+ # Insert missing edges from ground truth
44
+ edges_to_insert = gt_edges_set - pd_edges_set
45
+ insertion_edge_costs = ce * sum(np.linalg.norm(gt_vertices[edge[0]] - gt_vertices[edge[1]]) for edge in edges_to_insert)
46
+
47
+ # Step 5: Calculation of WED
48
+ WED = translation_costs + deletion_costs + insertion_costs + deletion_edge_costs + insertion_edge_costs
49
+
50
+ if normalized:
51
+ total_length_of_gt_edges = np.linalg.norm((gt_vertices[gt_edges[:, 0]] - gt_vertices[gt_edges[:, 1]]), axis=1).sum()
52
+ WED = WED / total_length_of_gt_edges
53
+
54
+ return WED