nikigoli committed
Commit 14c8704 · verified · 1 Parent(s): a84ffa2

Added a check to verify the MS Deformable Attention (MSDeformAttn) install
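Context for the decorator used in the patch: on ZeroGPU Spaces, CUDA is only available inside functions wrapped with @spaces.GPU, which is why the install check itself is wrapped. A minimal sketch of that pattern (the function name gpu_smoke_test is illustrative, not part of this commit):

    import spaces
    import torch

    @spaces.GPU  # attaches a GPU for the duration of this call on ZeroGPU Spaces
    def gpu_smoke_test() -> bool:
        # CUDA is visible only while a @spaces.GPU-wrapped call is running.
        return torch.cuda.is_available()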

Files changed (1):
app.py +80 -1
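Before the diff, one more piece of context: the new check compares the compiled MultiScaleDeformableAttention CUDA kernel against the pure-PyTorch baseline ms_deform_attn_core_pytorch imported in the patch. A simplified sketch of that baseline, assuming it matches the public Deformable-DETR reference (the function name and the plain-int spatial_shapes here are illustrative; the version in this repo may differ):

    import torch
    import torch.nn.functional as F

    def ms_deform_attn_pytorch_sketch(value, spatial_shapes, sampling_locations, attention_weights):
        # value: (N, S, M, D) flattened multi-scale features; spatial_shapes: list of (H, W) ints
        N, S, M, D = value.shape
        _, Lq, M, L, P, _ = sampling_locations.shape
        value_list = value.split([H * W for H, W in spatial_shapes], dim=1)
        sampling_grids = 2 * sampling_locations - 1  # map [0, 1] coords to grid_sample's [-1, 1]
        sampled = []
        for lid, (H, W) in enumerate(spatial_shapes):
            # (N, H*W, M, D) -> (N*M, D, H, W): each head becomes one grid_sample batch entry
            value_l = value_list[lid].flatten(2).transpose(1, 2).reshape(N * M, D, H, W)
            # (N, Lq, M, P, 2) -> (N*M, Lq, P, 2)
            grid_l = sampling_grids[:, :, :, lid].transpose(1, 2).flatten(0, 1)
            # bilinear sampling at the predicted locations: (N*M, D, Lq, P)
            sampled.append(F.grid_sample(value_l, grid_l, mode='bilinear',
                                         padding_mode='zeros', align_corners=False))
        # weight every sampled point and sum over levels and points
        weights = attention_weights.transpose(1, 2).reshape(N * M, 1, Lq, L * P)
        out = (torch.stack(sampled, dim=-2).flatten(-2) * weights).sum(-1)
        return out.view(N, M * D, Lq).transpose(1, 2).contiguous()

The baseline maps normalized sampling locations into grid_sample coordinates and does per-level bilinear sampling; the CUDA kernel reimplements this in fused form, so agreement between the two is what the check below asserts.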
app.py CHANGED
@@ -49,6 +49,85 @@ class AppSteps(Enum):
 
 CONF_THRESH = 0.23
 
+@spaces.GPU
+def check_ms_deform_install():
+    # Sanity-check the compiled MultiScaleDeformableAttention CUDA op against
+    # the pure-PyTorch reference (adapted from the Deformable-DETR test suite).
+    import torch
+    from torch.autograd import gradcheck
+
+    from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
+
+    # Tiny problem sizes: batch N, heads M, head dim D, queries Lq,
+    # feature levels L, sampling points P.
+    N, M, D = 1, 2, 2
+    Lq, L, P = 2, 2, 2
+    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
+    level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
+    S = sum((H * W).item() for H, W in shapes)
+
+    torch.manual_seed(3)
+
+    @torch.no_grad()
+    def check_forward_equal_with_pytorch_double():
+        # The CUDA forward pass must match the PyTorch reference in float64.
+        value = torch.rand(N, S, M, D).cuda() * 0.01
+        sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+        attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+        attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+        im2col_step = 2
+        output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
+        output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
+        fwdok = torch.allclose(output_cuda, output_pytorch)
+        max_abs_err = (output_cuda - output_pytorch).abs().max()
+        max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+        print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+    @torch.no_grad()
+    def check_forward_equal_with_pytorch_float():
+        # Same comparison in float32, with correspondingly looser tolerances.
+        value = torch.rand(N, S, M, D).cuda() * 0.01
+        sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+        attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+        attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+        im2col_step = 2
+        output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
+        output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
+        fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
+        max_abs_err = (output_cuda - output_pytorch).abs().max()
+        max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+        print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+    def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
+        # Analytic gradients must agree with finite differences (float64).
+        value = torch.rand(N, S, M, channels).cuda() * 0.01
+        sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+        attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+        attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+        im2col_step = 2
+        func = MSDeformAttnFunction.apply
+
+        value.requires_grad = grad_value
+        sampling_locations.requires_grad = grad_sampling_loc
+        attention_weights.requires_grad = grad_attn_weight
+
+        gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
+        print(f'* {gradok} check_gradient_numerical(D={channels})')
+
+    check_forward_equal_with_pytorch_double()
+    check_forward_equal_with_pytorch_float()
+
+    for channels in [30, 32, 64, 71]:
+        check_gradient_numerical(channels, True, True, True)
+
 # MODEL:
 def get_args_parser():
     """
@@ -164,7 +243,7 @@ def build_model_and_transforms(args):
 
 parser = argparse.ArgumentParser("Counting Application", parents=[get_args_parser()])
 args = parser.parse_args()
-
+check_ms_deform_install()
 device = get_device()
 model, transform = build_model_and_transforms(args)
 model = model.to(device)
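A note on the gradient portion of the check: torch.autograd.gradcheck compares analytic gradients with finite differences and expects float64 inputs with requires_grad=True, which is why the check casts everything to double. A minimal self-contained example of the same pattern (using torch.sin as a stand-in for the deformable-attention op):

    import torch
    from torch.autograd import gradcheck

    # gradcheck perturbs each input element and compares the finite-difference
    # Jacobian against the one autograd computes; it returns True or raises.
    x = torch.randn(3, dtype=torch.double, requires_grad=True)
    print(gradcheck(torch.sin, (x,)))  # True when analytic and numeric grads agree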