zhengrongzhang hangyang-amd commited on
Commit
f201632
1 Parent(s): 0b5f4ac

Update eval_onnx.py (#2)

Browse files

- Update eval_onnx.py (806b7855325fa5ebf88011d2eb16a555f587582c)


Co-authored-by: hang yang <hangyang-amd@users.noreply.huggingface.co>

Files changed (1) hide show
  1. eval_onnx.py +2 -1
eval_onnx.py CHANGED
@@ -34,6 +34,7 @@ parser.add_argument(
34
  default="vaip_config.json",
35
  help="Path of the config file for seting provider_options.",
36
  )
 
37
  args = parser.parse_args()
38
 
39
  class AverageMeter(object):
@@ -144,7 +145,7 @@ def val_imagenet():
144
  val_loader = tqdm(val_loader, file=sys.stdout)
145
  with torch.no_grad():
146
  for batch_idx, (images, targets) in enumerate(val_loader):
147
- inputs, targets = images.numpy(), targets
148
  ort_inputs = {ort_session.get_inputs()[0].name: inputs}
149
 
150
  outputs = ort_session.run(None, ort_inputs)
 
34
  default="vaip_config.json",
35
  help="Path of the config file for seting provider_options.",
36
  )
37
+ parser.add_argument('--data_format', type=str, choices=["nchw", "nhwc"], default="nchw")
38
  args = parser.parse_args()
39
 
40
  class AverageMeter(object):
 
145
  val_loader = tqdm(val_loader, file=sys.stdout)
146
  with torch.no_grad():
147
  for batch_idx, (images, targets) in enumerate(val_loader):
148
+ inputs, targets = images.numpy() if args.data_format == "nchw" else images.permute((0, 2, 3, 1)).numpy(), targets
149
  ort_inputs = {ort_session.get_inputs()[0].name: inputs}
150
 
151
  outputs = ort_session.run(None, ort_inputs)