andreysher committed
Commit: 4d679c2
Parent(s): 632e862

imagenet-benchmark

Changed files:
- MobileNetV2/MobileNetV2-ENOT.onnx +3 -0
- MobileNetV2/MobileNetV2-ENOT.pth +3 -0
- README.md +58 -0
- ResNet-50/ResNet50-ENOT-x2.onnx +3 -0
- ResNet-50/ResNet50-ENOT-x2.pth +3 -0
- ResNet-50/ResNet50-ENOT-x4.onnx +3 -0
- ResNet-50/ResNet50-ENOT-x4.pth +3 -0
- ViT-B-32/ViT-B-32-ENOT.onnx +3 -0
- ViT-B-32/ViT-B-32-ENOT.pth +3 -0
- measure_mac.py +28 -0
- requirements.txt +5 -0
- test.py +180 -0
- utils.py +208 -0
MobileNetV2/MobileNetV2-ENOT.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fae5b0822282cce7cec83d63b96af7bd12deae8e8371083b28a9bc6002e08a7d
size 10682115
MobileNetV2/MobileNetV2-ENOT.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0d39fa80cba1d431eea3009c7ae0bf506fb7e6c6c97853994329ebc03a1fc40e
size 32641690
README.md
CHANGED
@@ -1,3 +1,61 @@
---
license: apache-2.0
datasets:
- imagenet-1k
library_name: torchvision
pipeline_tag: image-classification
tags:
- onnx
- ENOT-AutoDL
---

# ENOT-AutoDL pruning benchmark on ImageNet-1k

This repository contains models accelerated with the [ENOT-AutoDL](https://pypi.org/project/enot-autodl/) framework.
Models from [Torchvision](https://pytorch.org/vision/stable/models.html) are used as baselines.
The evaluation code is also based on the Torchvision references.
In the tables below, the numbers in parentheses give the MMAC reduction factor and the absolute top-1 accuracy change relative to the baseline (e.g. 4144.854 / 2057.615 ≈ x2.014).
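
The MMAC counts for the pruned checkpoints are measured with `measure_mac.py` (included below). As a sanity check, the baseline rows can be reproduced the same way; a minimal sketch for the Torchvision ResNet-50 baseline, under the assumption that the baselines are Torchvision's default `IMAGENET1K_V1` weights:

```python
# Sanity-check sketch, assuming the baseline is Torchvision's default
# IMAGENET1K_V1 ResNet-50. Mirrors measure_mac.py: fvcore's FlopCountAnalysis
# counts multiply-accumulate operations, so this should report ~4144.854 MMACs.
import torch
from fvcore.nn import FlopCountAnalysis
from torchvision.models import resnet50

model = resnet50(weights="IMAGENET1K_V1").eval()
macs = FlopCountAnalysis(model, torch.ones((1, 3, 224, 224))).total()
print(f"MMACs = {macs / 1e6}")
```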

## ResNet-50

| Model                     | MMACs             | Accuracy (%)    |
| ------------------------- | :---------------: | :-------------: |
| **ResNet-50 Torchvision** | 4144.854          | 76.144          |
| **ResNet-50 ENOT (x2)**   | 2057.615 (x2.014) | 75.482 (-0.662) |
| **ResNet-50 ENOT (x4)**   | 867.943 (x4.775)  | 73.576 (-2.568) |

## ViT-B/32

| Model                    | MMACs            | Accuracy (%)    |
| ------------------------ | :--------------: | :-------------: |
| **ViT-B/32 Torchvision** | 4413.986         | 75.912          |
| **ViT-B/32 ENOT**        | 492.232 (x8.967) | 73.718 (-2.194) |

## MobileNetV2

| Model                       | MMACs            | Accuracy (%)    |
| --------------------------- | :--------------: | :-------------: |
| **MobileNetV2 Torchvision** | 334.227          | 71.878          |
| **MobileNetV2 ENOT**        | 156.800 (x2.131) | 69.898 (-1.980) |

# Validation

To validate the results, follow these steps:

1. Install all required packages:
   ```bash
   pip install -r requirements.txt
   ```
1. Calculate model complexity (MMACs):
   ```bash
   python measure_mac.py --model-ckpt path/to/model.pth
   ```
1. Measure the accuracy of an ONNX model:
   ```bash
   python test.py --data-path path/to/imagenet --model-onnx path/to/model.onnx --batch-size 1
   ```
1. Measure the accuracy of a PyTorch (.pth) model (see the sketch after this list for loading a checkpoint directly):
   ```bash
   python test.py --data-path path/to/imagenet --model-ckpt path/to/model.pth
   ```
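
The `.pth` files store the whole pruned model object rather than a plain state dict, so they can also be loaded directly. A minimal single-image sketch, assuming the checkpoint layout used by `test.py` and `measure_mac.py` (the image path is a placeholder):

```python
# Minimal sketch: load a pruned checkpoint directly and classify one image.
# Assumes the "model_ckpt" key holds the full serialized model, as in
# test.py / measure_mac.py; "cat.jpg" is a placeholder input.
import torch
import torchvision.transforms as T
from PIL import Image

checkpoint = torch.load("ResNet-50/ResNet50-ENOT-x2.pth", map_location="cpu")
model = checkpoint["model_ckpt"].eval()

# Same ImageNet evaluation preprocessing as load_data() in test.py.
preprocess = T.Compose([
    T.Resize(256, interpolation=T.InterpolationMode.BILINEAR),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

image = preprocess(Image.open("cat.jpg").convert("RGB")).unsqueeze(0)
with torch.inference_mode():
    logits = model(image)
print("Top-1 ImageNet class id:", logits.argmax(dim=1).item())
```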

If you want to book a demo, please [contact us](mailto:enot@enot.ai).
ResNet-50/ResNet50-ENOT-x2.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9f689ec182909427df72d390d425eb3b72618d4c40ae089b93b66ea14c6adf5f
size 50666788
ResNet-50/ResNet50-ENOT-x2.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c0b29e2ac563332d02274d6d656379d3b0957b91b7c8b6c1b4433657d74d6e68
size 101839301
ResNet-50/ResNet50-ENOT-x4.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:387b705b1d83c844f513d7646f95138f8fcfb420e1ef0b5f8d7039e550c66b91
size 20850032
ResNet-50/ResNet50-ENOT-x4.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7a9b7d6ac9062b92da9b44f61ace8c62da76ce86fb0947fdb40fb449792e194a
size 62177349
ViT-B-32/ViT-B-32-ENOT.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:63a4d0e19cfbeca9dca0b18aaf5c60b2d845c05a25e7641b954e90839efda63b
size 39430730
ViT-B-32/ViT-B-32-ENOT.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:92a81cef913af4012215215400049317168939624f09d76e9043aee2342af356
size 157444613
measure_mac.py
ADDED
@@ -0,0 +1,28 @@
import argparse

import torch
from fvcore.nn import FlopCountAnalysis


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-ckpt", type=str)

    return parser.parse_args()


def main():
    args = get_args()

    # The checkpoint stores the whole serialized model under the "model_ckpt" key.
    checkpoint = torch.load(args.model_ckpt, map_location="cpu")
    model = checkpoint["model_ckpt"]
    model.eval()

    # fvcore's FlopCountAnalysis counts multiply-accumulate operations (MACs)
    # for a single 224x224 RGB input.
    flops = FlopCountAnalysis(model.cpu(), torch.ones((1, 3, 224, 224)))
    flops = flops.total()

    print(f"MMACs = {flops/1e6}")


if __name__ == "__main__":
    main()
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch==1.13.1
torchvision==0.14.1
fvcore==0.1.5.post20221221
onnxruntime-gpu==1.15.1
onnx==1.13.1
test.py
ADDED
@@ -0,0 +1,180 @@
import os

import onnxruntime
import torch
import torch.utils.data
import torchvision
from torch import nn
from torchvision.transforms.functional import InterpolationMode

import utils


def evaluate(
    criterion,
    data_loader,
    device,
    model=None,
    model_onnx_path=None,
    print_freq=100,
    log_suffix="",
):
    if model_onnx_path:
        session = onnxruntime.InferenceSession(
            model_onnx_path, providers=["CPUExecutionProvider"]
        )
        input_name = session.get_inputs()[0].name

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            target = target.to(device, non_blocking=True)
            image = image.to(device)

            if model_onnx_path:
                # from torch to numpy (ort)
                input_data = image.cpu().numpy()

                output_data = session.run([], {input_name: input_data})[0]

                # from numpy to torch
                output = torch.from_numpy(output_data).to(device)
            elif model:
                output = model(image)
            else:
                raise ValueError("Either model or model_onnx_path must be provided")

            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size
    # gather the stats from all processes

    metric_logger.synchronize_between_processes()

    print(
        f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}"
    )
    return metric_logger.acc1.global_avg


def load_data(valdir):
    # Data loading code
    print("Loading data")
    interpolation = InterpolationMode("bilinear")

    preprocessing = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(256, interpolation=interpolation),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.PILToTensor(),
            torchvision.transforms.ConvertImageDtype(torch.float),
            torchvision.transforms.Normalize(
                mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
            ),
        ]
    )

    dataset_test = torchvision.datasets.ImageFolder(
        valdir,
        preprocessing,
    )

    print("Creating data loaders")
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset_test, test_sampler


def main(args):
    print(args)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    val_dir = os.path.join(args.data_path, "val")
    dataset_test, test_sampler = load_data(val_dir)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
    )

    print("Creating model")

    criterion = nn.CrossEntropyLoss()

    model = None
    if args.model_ckpt:
        # The checkpoint stores the whole serialized model under "model_ckpt";
        # if EMA weights are present, they replace the model weights.
        checkpoint = torch.load(args.model_ckpt, map_location="cpu")
        model = checkpoint["model_ckpt"]
        if "model_ema" in checkpoint:
            state_dict = {}
            for key, value in checkpoint["model_ema"].items():
                if "module." not in key:
                    continue
                state_dict[key.replace("module.", "")] = value
            model.load_state_dict(state_dict)
        model = model.to(device)

    accuracy = evaluate(
        model=model,
        model_onnx_path=args.model_onnx,
        criterion=criterion,
        data_loader=data_loader_test,
        device=device,
    )
    print(f"Model accuracy is: {accuracy}")


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(
        description="PyTorch Classification Evaluation", add_help=add_help
    )

    parser.add_argument(
        "--data-path", default="datasets/imagenet", type=str, help="dataset path"
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        default=32,
        type=int,
        help="images per gpu, the total batch size is $NGPU x batch_size",
    )
    parser.add_argument(
        "-j",
        "--workers",
        default=16,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 16)",
    )
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument(
        "--model-onnx", default="", type=str, help="path of .onnx checkpoint"
    )
    parser.add_argument(
        "--model-ckpt", default="", type=str, help="path of .pth checkpoint"
    )

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)
utils.py
ADDED
@@ -0,0 +1,208 @@
import datetime
import time
from collections import defaultdict
from collections import deque

import torch
import torch.distributed as dist


class SmoothedValue:
    """Track a series of values and provide access to smoothed values over a
    window or the global series average."""

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        t = reduce_across_processes([self.count, self.total])
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )


class MetricLogger:
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{attr}'"
        )

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {str(meter)}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                ]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f"{header} Total time: {total_time_str}")


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def reduce_across_processes(val):
    if not is_dist_avail_and_initialized():
        # nothing to sync, but we still convert to tensor for consistency with the distributed case.
        return torch.tensor(val)

    t = torch.tensor(val, device="cuda")
    dist.barrier()
    dist.all_reduce(t)
    return t


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified
    values of k."""
    with torch.inference_mode():
        maxk = max(topk)
        batch_size = target.size(0)
        if target.ndim == 2:
            target = target.max(dim=1)[1]

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target[None])

        res = []
        for k in topk:
            correct_k = correct[:k].flatten().sum(dtype=torch.float32)
            res.append(correct_k * (100.0 / batch_size))
        return res