haoliangtan
commited on
Commit
•
b3ed605
1
Parent(s):
6f8454f
Update eval_onnx.py
Browse files- eval_onnx.py +2 -1
eval_onnx.py
CHANGED
@@ -141,6 +141,7 @@ def val_imagenet():
|
|
141 |
val_loader = tqdm(val_loader, file=sys.stdout)
|
142 |
with torch.no_grad():
|
143 |
for batch_idx, (images, targets) in enumerate(val_loader):
|
|
|
144 |
inputs, targets = images.numpy(), targets
|
145 |
ort_inputs = {ort_session.get_inputs()[0].name: inputs}
|
146 |
|
@@ -158,4 +159,4 @@ def val_imagenet():
|
|
158 |
return top1.avg, top5.avg
|
159 |
|
160 |
if __name__ == '__main__':
|
161 |
-
val_imagenet()
|
|
|
141 |
val_loader = tqdm(val_loader, file=sys.stdout)
|
142 |
with torch.no_grad():
|
143 |
for batch_idx, (images, targets) in enumerate(val_loader):
|
144 |
+
images = torch.permute(images, (0, 2, 3, 1))
|
145 |
inputs, targets = images.numpy(), targets
|
146 |
ort_inputs = {ort_session.get_inputs()[0].name: inputs}
|
147 |
|
|
|
159 |
return top1.avg, top5.avg
|
160 |
|
161 |
if __name__ == '__main__':
|
162 |
+
val_imagenet()
|