kbrodt commited on
Commit
74d6764
1 Parent(s): 78d75d4

Upload losses.py

Browse files
Files changed (1) hide show
  1. src/losses.py +202 -0
src/losses.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ import pose_estimation
7
+
8
+
9
+ class MSE(nn.Module):
10
+ def __init__(self, ignore=None):
11
+ super().__init__()
12
+
13
+ self.mse = torch.nn.MSELoss(reduction="none")
14
+ self.ignore = ignore if ignore is not None else []
15
+
16
+ def forward(self, y_pred, y_data):
17
+ loss = self.mse(y_pred, y_data)
18
+
19
+ if len(self.ignore) > 0:
20
+ loss[self.ignore] *= 0
21
+
22
+ return loss.sum() / (len(loss) - len(self.ignore))
23
+
24
+
25
+ class Parallel(nn.Module):
26
+ def __init__(self, skeleton, ignore=None, ground_parallel=None):
27
+ super().__init__()
28
+
29
+ self.skeleton = skeleton
30
+ if ignore is not None:
31
+ self.ignore = set(ignore)
32
+ else:
33
+ self.ignore = set()
34
+
35
+ self.ground_parallel = ground_parallel if ground_parallel is not None else []
36
+ self.parallel_in_3d = []
37
+
38
+ self.cos = None
39
+
40
+ def forward(self, y_pred3d, y_data, z, spine_j, global_step=0):
41
+ y_pred = y_pred3d[:, :2]
42
+ rleg, lleg = spine_j
43
+
44
+ Lcon2d = Lcount = 0
45
+ if hasattr(self, "contact_2d"):
46
+ for c2d in self.contact_2d:
47
+ for (
48
+ (src_1, dst_1, t_1),
49
+ (src_2, dst_2, t_2),
50
+ ) in itertools.combinations(c2d, 2):
51
+
52
+ a_1 = torch.lerp(y_data[src_1], y_data[dst_1], t_1)
53
+ a_2 = torch.lerp(y_data[src_2], y_data[dst_2], t_2)
54
+ a = a_2 - a_1
55
+
56
+ b_1 = torch.lerp(y_pred[src_1], y_pred[dst_1], t_1)
57
+ b_2 = torch.lerp(y_pred[src_2], y_pred[dst_2], t_2)
58
+ b = b_2 - b_1
59
+
60
+ lcon2d = ((a - b) ** 2).sum()
61
+ Lcon2d = Lcon2d + lcon2d
62
+ Lcount += 1
63
+
64
+ if Lcount > 0:
65
+ Lcon2d = Lcon2d / Lcount
66
+
67
+ Ltan = Lpar = Lcos = Lcount = 0
68
+ Lspine = 0
69
+ for i, bone in enumerate(self.skeleton):
70
+ if bone in self.ignore:
71
+ continue
72
+
73
+ src, dst = bone
74
+
75
+ b = y_data[dst] - y_data[src]
76
+ t = nn.functional.normalize(b, dim=0)
77
+ n = torch.stack([-t[1], t[0]])
78
+
79
+ if src == 10 and dst == 11: # right leg
80
+ a = rleg
81
+ elif src == 13 and dst == 14: # left leg
82
+ a = lleg
83
+ else:
84
+ a = y_pred[dst] - y_pred[src]
85
+
86
+ bone_name = f"{pose_estimation.KPS[src]}_{pose_estimation.KPS[dst]}"
87
+ c = a - b
88
+ lcos_loc = ltan_loc = lpar_loc = 0
89
+ if self.cos is not None:
90
+ if bone not in [
91
+ (1, 2), # Neck + Right Shoulder
92
+ (1, 5), # Neck + Left Shoulder
93
+ (9, 10), # Hips + Right Upper Leg
94
+ (9, 13), # Hips + Left Upper Leg
95
+ ]:
96
+ a = y_pred[dst] - y_pred[src]
97
+ l2d = torch.norm(a, dim=0)
98
+ l3d = torch.norm(y_pred3d[dst] - y_pred3d[src], dim=0)
99
+ lcos = self.cos[i]
100
+
101
+ lcos_loc = (l2d / l3d - lcos) ** 2
102
+ Lcos = Lcos + lcos_loc
103
+ lpar_loc = ((a / l2d) * n).sum() ** 2
104
+ Lpar = Lpar + lpar_loc
105
+ else:
106
+ ltan_loc = ((c * t).sum()) ** 2
107
+ Ltan = Ltan + ltan_loc
108
+ lpar_loc = (c * n).sum() ** 2
109
+ Lpar = Lpar + lpar_loc
110
+
111
+ Lcount += 1
112
+
113
+ if Lcount > 0:
114
+ Ltan = Ltan / Lcount
115
+ Lcos = Lcos / Lcount
116
+ Lpar = Lpar / Lcount
117
+ Lspine = Lspine / Lcount
118
+
119
+ Lgr = Lcount = 0
120
+ for (src, dst), value in self.ground_parallel:
121
+ bone = y_pred[dst] - y_pred[src]
122
+ bone = nn.functional.normalize(bone, dim=0)
123
+ l = (torch.abs(bone[0]) - value) ** 2
124
+
125
+ Lgr = Lgr + l
126
+ Lcount += 1
127
+
128
+ if Lcount > 0:
129
+ Lgr = Lgr / Lcount
130
+
131
+ Lstraight3d = Lcount = 0
132
+ for (i, j), (k, l) in self.parallel_in_3d:
133
+ a = z[j] - z[i]
134
+ a = nn.functional.normalize(a, dim=0)
135
+ b = z[l] - z[k]
136
+ b = nn.functional.normalize(b, dim=0)
137
+ lo = (((a * b).sum() - 1) ** 2).sum()
138
+ Lstraight3d = Lstraight3d + lo
139
+ Lcount += 1
140
+
141
+ b = y_data[1] - y_data[8]
142
+ b = nn.functional.normalize(b, dim=0)
143
+
144
+ if Lcount > 0:
145
+ Lstraight3d = Lstraight3d / Lcount
146
+
147
+ return Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d
148
+
149
+
150
+ class MimickedSelfContactLoss(nn.Module):
151
+ def __init__(self, geodesics_mask):
152
+ super().__init__()
153
+ """
154
+ Loss that lets vertices in contact on presented mesh attract vertices that are close.
155
+ """
156
+ # geodesic distance mask
157
+ self.register_buffer("geomask", geodesics_mask)
158
+
159
+ def forward(
160
+ self,
161
+ presented_contact,
162
+ vertices,
163
+ v2v=None,
164
+ contact_mode="dist_tanh",
165
+ contact_thresh=1,
166
+ ):
167
+
168
+ contactloss = 0.0
169
+
170
+ if v2v is None:
171
+ # compute pairwise distances
172
+ verts = vertices.contiguous()
173
+ nv = verts.shape[1]
174
+ v2v = verts.squeeze().unsqueeze(1).expand(
175
+ nv, nv, 3
176
+ ) - verts.squeeze().unsqueeze(0).expand(nv, nv, 3)
177
+ v2v = torch.norm(v2v, 2, 2)
178
+
179
+ # loss for self-contact from mimic'ed pose
180
+ if len(presented_contact) > 0:
181
+ # without geodesic distance mask, compute distances
182
+ # between each pair of verts in contact
183
+ with torch.no_grad():
184
+ cvertstobody = v2v[presented_contact, :]
185
+ cvertstobody = cvertstobody[:, presented_contact]
186
+ maskgeo = self.geomask[presented_contact, :]
187
+ maskgeo = maskgeo[:, presented_contact]
188
+ weights = torch.ones_like(cvertstobody).to(verts.device)
189
+ weights[~maskgeo] = float("inf")
190
+ min_idx = torch.min((cvertstobody + 1) * weights, 1)[1]
191
+ min_idx = presented_contact[min_idx.cpu().numpy()]
192
+
193
+ v2v_min = v2v[presented_contact, min_idx]
194
+
195
+ # tanh will not pull vertices that are ~more than contact_thres far apart
196
+ if contact_mode == "dist_tanh":
197
+ contactloss = contact_thresh * torch.tanh(v2v_min / contact_thresh)
198
+ contactloss = contactloss.mean()
199
+ else:
200
+ contactloss = v2v_min.mean()
201
+
202
+ return contactloss