haoliangtan commited on
Commit
b3ed605
1 Parent(s): 6f8454f

Update eval_onnx.py

Browse files
Files changed (1) hide show
  1. eval_onnx.py +2 -1
eval_onnx.py CHANGED
@@ -141,6 +141,7 @@ def val_imagenet():
141
  val_loader = tqdm(val_loader, file=sys.stdout)
142
  with torch.no_grad():
143
  for batch_idx, (images, targets) in enumerate(val_loader):
 
144
  inputs, targets = images.numpy(), targets
145
  ort_inputs = {ort_session.get_inputs()[0].name: inputs}
146
 
@@ -158,4 +159,4 @@ def val_imagenet():
158
  return top1.avg, top5.avg
159
 
160
  if __name__ == '__main__':
161
- val_imagenet()
 
141
  val_loader = tqdm(val_loader, file=sys.stdout)
142
  with torch.no_grad():
143
  for batch_idx, (images, targets) in enumerate(val_loader):
144
+ images = torch.permute(images, (0, 2, 3, 1))
145
  inputs, targets = images.numpy(), targets
146
  ort_inputs = {ort_session.get_inputs()[0].name: inputs}
147
 
 
159
  return top1.avg, top5.avg
160
 
161
  if __name__ == '__main__':
162
+ val_imagenet()