nikigoli committed on
Commit a5e9c89
1 Parent(s): 64980f3

Added print statements in transformer file for debugging

Files changed (1)
  1. models/GroundingDINO/transformer.py +14 -4
models/GroundingDINO/transformer.py CHANGED
@@ -237,6 +237,7 @@ class Transformer(nn.Module):
 
         """
         # prepare input for encoder
+        print("inside transformer forward")
         src_flatten = []
         mask_flatten = []
         lvl_pos_embed_flatten = []
@@ -273,7 +274,7 @@ class Transformer(nn.Module):
         #########################################################
         # Begin Encoder
         #########################################################
-
+        print("begin transformer encoder")
         memory, memory_text = self.encoder(
             src_flatten,
             pos=lvl_pos_embed_flatten,
@@ -287,7 +288,7 @@ class Transformer(nn.Module):
             position_ids=text_dict["position_ids"],
             text_self_attention_masks=text_dict["text_self_attention_masks"],
         )
-
+        print("got encoder output")
         #########################################################
         # End Encoder
         # - memory: bs, \sum{hw}, c
@@ -302,9 +303,11 @@ class Transformer(nn.Module):
         # import ipdb; ipdb.set_trace()
 
         if self.two_stage_type == "standard":  # use the encoder output as proposals
+            print("standard two stage")
             output_memory, output_proposals = gen_encoder_output_proposals(
                 memory, mask_flatten, spatial_shapes
             )
+            print("got output proposals")
             output_memory = self.enc_output_norm(self.enc_output(output_memory))
 
             if text_dict is not None:
@@ -321,24 +324,29 @@ class Transformer(nn.Module):
             topk = self.num_queries
 
             topk_proposals = torch.topk(topk_logits, topk, dim=1)[1]  # bs, nq
-
+            print("got topk proposals")
             # gather boxes
+            print("gather 1")
             refpoint_embed_undetach = torch.gather(
                 enc_outputs_coord_unselected,
                 1,
                 topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
             )  # unsigmoid
+            print("gathered 1")
             refpoint_embed_ = refpoint_embed_undetach.detach()
+            print("gather 2")
             init_box_proposal = torch.gather(
                 output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
             ).sigmoid()  # sigmoid
-
+            print("gathered 2")
+            print("gather 3")
             # gather tgt
             tgt_undetach = torch.gather(
                 output_memory,
                 1,
                 topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model),
             )
+            print("gathered 3")
             if self.embed_init_tgt:
                 tgt_ = (
                     self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
@@ -393,6 +401,7 @@ class Transformer(nn.Module):
         # memory torch.Size([2, 16320, 256])
 
         # import pdb;pdb.set_trace()
+        print("going through decoder")
         hs, references = self.decoder(
             tgt=tgt.transpose(0, 1),
             memory=memory.transpose(0, 1),
@@ -407,6 +416,7 @@ class Transformer(nn.Module):
             text_attention_mask=~text_dict["text_token_mask"],
             # we ~ the mask . False means use the token; True means pad the token
         )
+        print("got decoder output")
         #########################################################
         # End Decoder
         # hs: n_dec, bs, nq, d_model
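
For reference, a less invasive way to get the same trace output is to route it through Python's standard logging module, so the messages can be silenced without editing transformer.py again. The sketch below is illustrative only: the logger name and the GDINO_DEBUG environment variable are assumptions, not part of GroundingDINO.

# Minimal sketch (assumption: a module-level logger in models/GroundingDINO/transformer.py).
import logging
import os

logger = logging.getLogger("groundingdino.transformer")

# Enable the trace only when the (hypothetical) GDINO_DEBUG variable is set,
# e.g. `GDINO_DEBUG=1 python <your entry script>`.
if os.environ.get("GDINO_DEBUG"):
    logging.basicConfig(level=logging.DEBUG)

# Each print(...) in the diff above would then become, for example:
#     logger.debug("inside transformer forward")
#     logger.debug("begin transformer encoder")
# With GDINO_DEBUG unset, the logger.debug(...) calls produce no output.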