zihengg commited on
Commit
0856822
1 Parent(s): 21794d5

Update eval_onnx.py

Browse files
Files changed (1) hide show
  1. eval_onnx.py +12 -2
eval_onnx.py CHANGED
@@ -514,18 +514,28 @@ if __name__ == '__main__':
514
  data_loader = data.getEvalDataloader()
515
  # Load MoveNet model using ONNX runtime
516
  model = rt.InferenceSession(MODEL_DIR, providers=providers, provider_options=provider_options)
517
-
518
  correct = 0
519
  total = 0
520
  # Loop through the data loader for evaluation
521
  for batch_idx, (imgs, labels, kps_mask, img_names) in enumerate(data_loader):
 
522
  if batch_idx%100 == 0:
523
  print('Finish ',batch_idx)
 
524
  imgs = imgs.detach().cpu().numpy()
525
- output = model.run(['1548','1607','1665','1723'],{'blob.1':imgs})
 
 
 
 
 
526
  pre = movenetDecode(output, kps_mask,mode='output',img_size=IMG_SIZE)
527
  gt = movenetDecode(labels, kps_mask,mode='label',img_size=IMG_SIZE)
 
 
528
  acc = myAcc(pre, gt)
 
529
  correct += sum(acc)
530
  total += len(acc)
531
  # Compute and print accuracy based on evaluated data
 
514
  data_loader = data.getEvalDataloader()
515
  # Load MoveNet model using ONNX runtime
516
  model = rt.InferenceSession(MODEL_DIR, providers=providers, provider_options=provider_options)
517
+
518
  correct = 0
519
  total = 0
520
  # Loop through the data loader for evaluation
521
  for batch_idx, (imgs, labels, kps_mask, img_names) in enumerate(data_loader):
522
+
523
  if batch_idx%100 == 0:
524
  print('Finish ',batch_idx)
525
+
526
  imgs = imgs.detach().cpu().numpy()
527
+ imgs = imgs.transpose((0,2,3,1))
528
+ output = model.run(['1548_transpose','1607_transpose','1665_transpose','1723_transpose'],{'blob.1':imgs})
529
+ output[0] = output[0].transpose((0,3,1,2))
530
+ output[1] = output[1].transpose((0,3,1,2))
531
+ output[2] = output[2].transpose((0,3,1,2))
532
+ output[3] = output[3].transpose((0,3,1,2))
533
  pre = movenetDecode(output, kps_mask,mode='output',img_size=IMG_SIZE)
534
  gt = movenetDecode(labels, kps_mask,mode='label',img_size=IMG_SIZE)
535
+
536
+ #n
537
  acc = myAcc(pre, gt)
538
+
539
  correct += sum(acc)
540
  total += len(acc)
541
  # Compute and print accuracy based on evaluated data