Added print statements in transformer file for debugging
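The diff below traces the forward pass with bare print() calls at each stage boundary (encoder start/end, two-stage proposal selection, the gather steps, decoder start/end). As a rough equivalent, the same trace points could go through Python's logging module so they can be switched off without editing the file again; the logger name, function, and setup in this sketch are illustrative assumptions and are not part of the commit.

import logging

# Hypothetical logger name; the actual change below uses plain print() calls.
logger = logging.getLogger("groundingdino.transformer")

def trace_stages():
    # Same stage-boundary messages as the print statements added in the diff,
    # emitted at DEBUG level so they can be silenced via the logging config.
    logger.debug("inside transformer forward")
    logger.debug("begin transformer encoder")
    logger.debug("got encoder output")
    logger.debug("going through decoder")
    logger.debug("got decoder output")

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG, format="%(name)s: %(message)s")
    trace_stages()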
models/GroundingDINO/transformer.py CHANGED
@@ -237,6 +237,7 @@ class Transformer(nn.Module):
 
         """
         # prepare input for encoder
+        print("inside transformer forward")
         src_flatten = []
         mask_flatten = []
         lvl_pos_embed_flatten = []
@@ -273,7 +274,7 @@ class Transformer(nn.Module):
         #########################################################
         # Begin Encoder
         #########################################################
-
+        print("begin transformer encoder")
         memory, memory_text = self.encoder(
             src_flatten,
             pos=lvl_pos_embed_flatten,
@@ -287,7 +288,7 @@ class Transformer(nn.Module):
             position_ids=text_dict["position_ids"],
             text_self_attention_masks=text_dict["text_self_attention_masks"],
         )
-
+        print("got encoder output")
         #########################################################
         # End Encoder
         # - memory: bs, \sum{hw}, c
@@ -302,9 +303,11 @@ class Transformer(nn.Module):
         # import ipdb; ipdb.set_trace()
 
         if self.two_stage_type == "standard":  # 把encoder的输出作为proposal
+            print("standard two stage")
             output_memory, output_proposals = gen_encoder_output_proposals(
                 memory, mask_flatten, spatial_shapes
             )
+            print("got output proposals")
             output_memory = self.enc_output_norm(self.enc_output(output_memory))
 
             if text_dict is not None:
@@ -321,24 +324,29 @@ class Transformer(nn.Module):
             topk = self.num_queries
 
             topk_proposals = torch.topk(topk_logits, topk, dim=1)[1]  # bs, nq
-
+            print("got topk proposals")
             # gather boxes
+            print("gather 1")
             refpoint_embed_undetach = torch.gather(
                 enc_outputs_coord_unselected,
                 1,
                 topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
             )  # unsigmoid
+            print("gathered 1")
             refpoint_embed_ = refpoint_embed_undetach.detach()
+            print("gather 2")
             init_box_proposal = torch.gather(
                 output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
             ).sigmoid()  # sigmoid
+            print("gathered 2")
+            print("gather 3")
             # gather tgt
             tgt_undetach = torch.gather(
                 output_memory,
                 1,
                 topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model),
             )
+            print("gathered 3")
             if self.embed_init_tgt:
                 tgt_ = (
                     self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
@@ -393,6 +401,7 @@ class Transformer(nn.Module):
         # memory torch.Size([2, 16320, 256])
 
         # import pdb;pdb.set_trace()
+        print("going through decoder")
         hs, references = self.decoder(
             tgt=tgt.transpose(0, 1),
             memory=memory.transpose(0, 1),
@@ -407,6 +416,7 @@ class Transformer(nn.Module):
             text_attention_mask=~text_dict["text_token_mask"],
             # we ~ the mask . False means use the token; True means pad the token
         )
+        print("got decoder output")
         #########################################################
         # End Decoder
         # hs: n_dec, bs, nq, d_model
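For context, the section traced by the "gather 1"…"gather 3" prints selects the top-scoring encoder proposals and gathers their boxes and features. A self-contained sketch of that torch.topk/torch.gather pattern on dummy tensors follows; the shapes and random inputs are made up for illustration and are not taken from the model.

import torch

# Dummy sizes, loosely modeled on the comments in the diff (bs, \sum{hw}, c).
bs, num_tokens, d_model, num_queries = 2, 16320, 256, 900

topk_logits = torch.rand(bs, num_tokens)                      # per-token scores
enc_outputs_coord_unselected = torch.rand(bs, num_tokens, 4)  # unsigmoided boxes
output_memory = torch.rand(bs, num_tokens, d_model)           # encoder features

# Indices of the top-k proposals per batch element: bs, nq
topk_proposals = torch.topk(topk_logits, num_queries, dim=1)[1]

# Gather the boxes and features belonging to the selected proposals.
refpoint_embed_undetach = torch.gather(
    enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
)
tgt_undetach = torch.gather(
    output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, d_model)
)

print(refpoint_embed_undetach.shape)  # torch.Size([2, 900, 4])
print(tgt_undetach.shape)             # torch.Size([2, 900, 256])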