Documentation

- README.md +15 -15
- dnnlib/__init__.py +1 -1
- docs/dataset-tool-help.txt +50 -50
- docs/train-help.txt +69 -69
- metrics/frechet_inception_distance.py +7 -2
- metrics/inception_score.py +6 -2
- metrics/kernel_inception_distance.py +7 -3
- metrics/perceptual_path_length.py +5 -1
- metrics/precision_recall.py +5 -1
- torch_utils/persistence.py +125 -21
- torch_utils/training_stats.py +149 -38
README.md
CHANGED
@@ -203,7 +203,7 @@ python dataset_tool.py --source=~/downloads/afhq/train/wild --dest=~/datasets/af
 python dataset_tool.py --source=~/downloads/cifar-10-python.tar.gz --dest=~/datasets/cifar10.zip
 ```
 
-**LSUN**: Download the desired
+**LSUN**: Download the desired categories from the [LSUN project page](https://www.yf.io/p/lsun/) and convert to ZIP archive:
 
 ```.bash
 python dataset_tool.py --source=~/downloads/lsun/raw/cat_lmdb --dest=~/datasets/lsuncat200k.zip \
@@ -262,7 +262,7 @@ The training configuration can be further customized with additional command lin
 * `--cond=1` enables class-conditional training (requires a dataset with labels).
 * `--mirror=1` amplifies the dataset with x-flips. Often beneficial, even with ADA.
 * `--resume=ffhq1024 --snap=10` performs transfer learning from FFHQ trained at 1024x1024.
-* `--resume=~/training-runs/<NAME>/network-snapshot-<
+* `--resume=~/training-runs/<NAME>/network-snapshot-<INT>.pkl` resumes a previous training run.
 * `--gamma=10` overrides R1 gamma. We recommend trying a couple of different values for each new dataset.
 * `--aug=ada --target=0.7` adjusts ADA target value (default: 0.6).
 * `--augpipe=blit` enables pixel blitting but disables all other augmentations.
@@ -293,7 +293,7 @@ The total training time depends heavily on resolution, number of GPUs, dataset,
 | 1024x1024 | 4 | 11h 36m | 12d 02h | 40.1–40.8 | 8.4 GB | 21.9 GB |
 | 1024x1024 | 8 | 5h 54m | 6d 03h | 20.2–20.6 | 8.3 GB | 44.7 GB |
 
-The above measurements were done using NVIDIA Tesla V100 GPUs with default settings (`--cfg=auto --aug=ada --metrics=fid50k_full`). "sec/kimg" shows the expected range of variation in raw training performance, as reported in `log.txt
+The above measurements were done using NVIDIA Tesla V100 GPUs with default settings (`--cfg=auto --aug=ada --metrics=fid50k_full`). "sec/kimg" shows the expected range of variation in raw training performance, as reported in `log.txt`. "GPU mem" and "CPU mem" show the highest observed memory consumption, excluding the peak at the beginning caused by `torch.backends.cudnn.benchmark`.
 
 In typical cases, 25000 kimg or more is needed to reach convergence, but the results are already quite reasonable around 5000 kimg. 1000 kimg is often enough for transfer learning, which tends to converge significantly faster. The following figure shows example convergence curves for different datasets as a function of wallclock time, using the same settings as above:
 
@@ -325,23 +325,23 @@ We employ the following metrics in the ADA paper. Execution time and GPU memory
 
 | Metric | Time | GPU mem | Description |
 | :----- | :----: | :-----: | :---------- |
-| `fid50k_full` | 13 min | 1.8 GB | Fréchet inception distance<sup>[1]</sup> against the full dataset |
-| `kid50k_full` | 13 min | 1.8 GB | Kernel inception distance<sup>[2]</sup> against the full dataset |
-| `pr50k3_full` | 13 min | 4.1 GB | Precision and recall<sup>[3]</sup> against the full dataset |
-| `is50k` | 13 min | 1.8 GB | Inception score<sup>[4]</sup> for CIFAR-10 |
+| `fid50k_full` | 13 min | 1.8 GB | Fréchet inception distance<sup>[1]</sup> against the full dataset |
+| `kid50k_full` | 13 min | 1.8 GB | Kernel inception distance<sup>[2]</sup> against the full dataset |
+| `pr50k3_full` | 13 min | 4.1 GB | Precision and recall<sup>[3]</sup> against the full dataset |
+| `is50k` | 13 min | 1.8 GB | Inception score<sup>[4]</sup> for CIFAR-10 |
 
 In addition, the following metrics from the [StyleGAN](https://github.com/NVlabs/stylegan) and [StyleGAN2](https://github.com/NVlabs/stylegan2) papers are also supported:
 
 | Metric | Time | GPU mem | Description |
 | :------------ | :----: | :-----: | :---------- |
-| `fid50k` | 13 min | 1.8 GB | Fréchet inception distance against 50k real images |
-| `kid50k` | 13 min | 1.8 GB | Kernel inception distance against 50k real images |
-| `pr50k3` | 13 min | 4.1 GB | Precision and recall against 50k real images |
-| `ppl2_wend` | 36 min | 2.4 GB | Perceptual path length<sup>[5]</sup> in W |
-| `ppl_zfull` | 36 min | 2.4 GB | Perceptual path length in Z |
-| `ppl_wfull` | 36 min | 2.4 GB | Perceptual path length in W |
-| `ppl_zend` | 36 min | 2.4 GB | Perceptual path length in Z |
-| `ppl_wend` | 36 min | 2.4 GB | Perceptual path length in W |
+| `fid50k` | 13 min | 1.8 GB | Fréchet inception distance against 50k real images |
+| `kid50k` | 13 min | 1.8 GB | Kernel inception distance against 50k real images |
+| `pr50k3` | 13 min | 4.1 GB | Precision and recall against 50k real images |
+| `ppl2_wend` | 36 min | 2.4 GB | Perceptual path length<sup>[5]</sup> in W, endpoints, full image |
+| `ppl_zfull` | 36 min | 2.4 GB | Perceptual path length in Z, full paths, cropped image |
+| `ppl_wfull` | 36 min | 2.4 GB | Perceptual path length in W, full paths, cropped image |
+| `ppl_zend` | 36 min | 2.4 GB | Perceptual path length in Z, endpoints, cropped image |
+| `ppl_wend` | 36 min | 2.4 GB | Perceptual path length in W, endpoints, cropped image |
 
 References:
 1. [GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium](https://arxiv.org/abs/1706.08500), Heusel et al. 2017
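Aside: the snapshot pickles referenced by `--resume` are self-contained (see `torch_utils/persistence.py` below), so they can also be loaded outside of `train.py`. A minimal sketch, assuming the snapshot is a dict with `'G'`, `'D'`, and `'G_ema'` entries as written by the training loop:

```python
# Minimal sketch: load a network snapshot for inspection. Assumes the pickle
# holds {'G': ..., 'D': ..., 'G_ema': ...} and that dnnlib/torch_utils from
# this repo are importable (the pickled source code still imports them).
import pickle
import torch

with open('network-snapshot-000100.pkl', 'rb') as f:   # hypothetical filename
    snapshot = pickle.load(f)

G = snapshot['G_ema']                  # moving-average generator weights
z = torch.randn([1, G.z_dim])          # random latent vector
c = torch.zeros([1, G.c_dim])          # class labels (all-zero if unconditional)
img = G(z, c)                          # NCHW float32 image, roughly in [-1, +1]
```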
dnnlib/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # NVIDIA CORPORATION and its licensors retain all intellectual property
 # and proprietary rights in and to this software, related documentation
docs/dataset-tool-help.txt
CHANGED
@@ -1,50 +1,50 @@
Usage: dataset_tool.py [OPTIONS]

  Convert an image dataset into a dataset archive usable with StyleGAN2 ADA
  PyTorch.

  The input dataset format is guessed from the --source argument:

  --source *_lmdb/                 - Load LSUN dataset
  --source cifar-10-python.tar.gz  - Load CIFAR-10 dataset
  --source path/                   - Recursively load all images from path/
  --source dataset.zip             - Recursively load all images from dataset.zip

  The output dataset format can be either an image folder or a zip archive.
  Specifying the output format and path:

  --dest /path/to/dir              - Save output files under /path/to/dir
  --dest /path/to/dataset.zip      - Save output files into /path/to/dataset.zip archive

  Images within the dataset archive will be stored as uncompressed PNG.

  Image scale/crop and resolution requirements:

  Output images must be square-shaped and they must all have the same power-
  of-two dimensions.

  To scale arbitrary input image size to a specific width and height, use
  the --width and --height options. Output resolution will be either the
  original input resolution (if --width/--height was not specified) or the
  one specified with --width/height.

  Use the --transform=center-crop or --transform=center-crop-wide options to
  apply a center crop transform on the input image. These options should be
  used with the --width and --height options. For example:

  python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \
      --transform=center-crop-wide --width 512 --height=384

Options:
  --source PATH                  Directory or archive name for input dataset
                                 [required]
  --dest PATH                    Output directory or archive name for output
                                 dataset  [required]
  --max-images INTEGER           Output only up to `max-images` images
  --resize-filter [box|lanczos]  Filter to use when resizing images for
                                 output resolution  [default: lanczos]
  --transform [center-crop|center-crop-wide]
                                 Input crop/resize mode
  --width INTEGER                Output width
  --height INTEGER               Output height
  --help                         Show this message and exit.
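A quick illustration of the resolution rule stated in the help text above (square, power-of-two). The helper below is hypothetical and not part of `dataset_tool.py`:

```python
# Hypothetical helper illustrating the "square, power-of-two" requirement.
def is_valid_resolution(width, height):
    return width == height and width >= 1 and (width & (width - 1)) == 0

assert is_valid_resolution(512, 512)        # accepted
assert not is_valid_resolution(512, 384)    # rejected: not square
assert not is_valid_resolution(384, 384)    # rejected: not a power of two
```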
docs/train-help.txt
CHANGED
@@ -1,69 +1,69 @@
Usage: train.py [OPTIONS]

  Train a GAN using the techniques described in the paper "Training
  Generative Adversarial Networks with Limited Data".

  Examples:

  # Train with custom images using 1 GPU.
  python train.py --outdir=~/training-runs --data=~/my-image-folder

  # Train class-conditional CIFAR-10 using 2 GPUs.
  python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \
      --gpus=2 --cfg=cifar --cond=1

  # Transfer learn MetFaces from FFHQ using 4 GPUs.
  python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \
      --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10

  # Reproduce original StyleGAN2 config F.
  python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \
      --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug

  Base configs (--cfg):
    auto       Automatically select reasonable defaults based on resolution
               and GPU count. Good starting point for new datasets.
    stylegan2  Reproduce results for StyleGAN2 config F at 1024x1024.
    paper256   Reproduce results for FFHQ and LSUN Cat at 256x256.
    paper512   Reproduce results for BreCaHAD and AFHQ at 512x512.
    paper1024  Reproduce results for MetFaces at 1024x1024.
    cifar      Reproduce results for CIFAR-10 at 32x32.

  Transfer learning source networks (--resume):
    ffhq256        FFHQ trained at 256x256 resolution.
    ffhq512        FFHQ trained at 512x512 resolution.
    ffhq1024       FFHQ trained at 1024x1024 resolution.
    celebahq256    CelebA-HQ trained at 256x256 resolution.
    lsundog256     LSUN Dog trained at 256x256 resolution.
    <PATH or URL>  Custom network pickle.

Options:
  --outdir DIR                    Where to save the results  [required]
  --gpus INT                      Number of GPUs to use  [default: 1]
  --snap INT                      Snapshot interval  [default: 50 ticks]
  --metrics LIST                  Comma-separated list or "none"  [default:
                                  fid50k_full]
  --seed INT                      Random seed  [default: 0]
  -n, --dry-run                   Print training options and exit
  --data PATH                     Training data (directory or zip)  [required]
  --cond BOOL                     Train conditional model based on dataset
                                  labels  [default: false]
  --subset INT                    Train with only N images  [default: all]
  --mirror BOOL                   Enable dataset x-flips  [default: false]
  --cfg [auto|stylegan2|paper256|paper512|paper1024|cifar]
                                  Base config  [default: auto]
  --gamma FLOAT                   Override R1 gamma
  --kimg INT                      Override training duration
  --batch INT                     Override batch size
  --aug [noaug|ada|fixed]         Augmentation mode  [default: ada]
  --p FLOAT                       Augmentation probability for --aug=fixed
  --target FLOAT                  ADA target value for --aug=ada
  --augpipe [blit|geom|color|filter|noise|cutout|bg|bgc|bgcf|bgcfn|bgcfnc]
                                  Augmentation pipeline  [default: bgc]
  --resume PKL                    Resume training  [default: noresume]
  --freezed INT                   Freeze-D  [default: 0 layers]
  --fp32 BOOL                     Disable mixed-precision training
  --nhwc BOOL                     Use NHWC memory format with FP16
  --nobench BOOL                  Disable cuDNN benchmarking
  --workers INT                   Override number of DataLoader workers
  --help                          Show this message and exit.
metrics/frechet_inception_distance.py
CHANGED
@@ -6,16 +6,21 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+"""Frechet Inception Distance (FID) from the paper
+"GANs trained by a two time-scale update rule converge to a local Nash
+equilibrium". Matches the original implementation by Heusel et al. at
+https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
+
 import numpy as np
 import scipy.linalg
-
 from . import metric_utils
 
 #----------------------------------------------------------------------------
 
 def compute_fid(opts, max_real, num_gen):
+    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
     detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
-    detector_kwargs = dict(return_features=True)
+    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
 
     mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
         opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
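Once the feature statistics (`mu`, `sigma`) are available for the real and generated distributions, FID reduces to the Fréchet distance between two Gaussians. A sketch of that final step (the function name is hypothetical; the rest of `compute_fid` is expected to do the equivalent):

```python
# Frechet distance between N(mu_real, sigma_real) and N(mu_gen, sigma_gen):
# ||mu_gen - mu_real||^2 + Tr(sigma_gen + sigma_real - 2*sqrt(sigma_gen @ sigma_real))
import numpy as np
import scipy.linalg

def frechet_distance(mu_real, sigma_real, mu_gen, sigma_gen):
    m = np.square(mu_gen - mu_real).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False)  # matrix sqrt
    return float(np.real(m + np.trace(sigma_gen + sigma_real - s * 2)))
```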
metrics/inception_score.py
CHANGED
@@ -6,15 +6,19 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
-
+"""Inception Score (IS) from the paper "Improved techniques for training
+GANs". Matches the original implementation by Salimans et al. at
+https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
 
+import numpy as np
 from . import metric_utils
 
 #----------------------------------------------------------------------------
 
 def compute_is(opts, num_gen, num_splits):
+    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
     detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
-    detector_kwargs = dict(no_output_bias=True)
+    detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
 
     gen_probs = metric_utils.compute_feature_stats_for_generator(
         opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
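After `gen_probs` (per-image class probabilities) is collected, the score is `exp(E[KL(p(y|x) || p(y))])`, averaged over `num_splits` disjoint subsets. A sketch of that reduction, assuming `gen_probs` is an `[N, num_classes]` array:

```python
# Inception Score over num_splits subsets of the probability matrix.
import numpy as np

def inception_score(gen_probs, num_splits=10):
    scores = []
    for split in np.array_split(gen_probs, num_splits):
        kl = split * (np.log(split) - np.log(np.mean(split, axis=0, keepdims=True)))
        scores.append(np.exp(np.mean(np.sum(kl, axis=1))))       # exp(mean KL)
    return float(np.mean(scores)), float(np.std(scores))         # mean, std across splits
```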
metrics/kernel_inception_distance.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # NVIDIA CORPORATION and its licensors retain all intellectual property
 # and proprietary rights in and to this software, related documentation
@@ -6,15 +6,19 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
-
+"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
+GANs". Matches the original implementation by Binkowski et al. at
+https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
 
+import numpy as np
 from . import metric_utils
 
 #----------------------------------------------------------------------------
 
 def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
+    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
     detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
-    detector_kwargs = dict(return_features=True)
+    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
 
     real_features = metric_utils.compute_feature_stats_for_dataset(
         opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
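KID is the unbiased MMD² estimate with the cubic polynomial kernel `k(x, y) = (x·y/d + 1)³`, averaged over `num_subsets` random subsets of size `max_subset_size`. A sketch of the per-subset estimate (helper name hypothetical):

```python
# Unbiased MMD^2 for one subset pair x (generated) and y (real), both [m, d].
import numpy as np

def kid_subset(x, y):
    d = x.shape[1]
    a = (x @ x.T / d + 1) ** 3 + (y @ y.T / d + 1) ** 3      # within-set kernels
    b = (x @ y.T / d + 1) ** 3                               # cross-set kernel
    m = x.shape[0]
    t = (a.sum() - np.trace(a)) / (m - 1) - b.sum() * 2 / m  # drop diagonal terms
    return t / m
```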
metrics/perceptual_path_length.py
CHANGED
@@ -6,11 +6,15 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
+Architecture for Generative Adversarial Networks". Matches the original
+implementation by Karras et al. at
+https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
+
 import copy
 import numpy as np
 import torch
 import dnnlib
-
 from . import metric_utils
 
 #----------------------------------------------------------------------------
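PPL measures the perceptual distance between images generated from two latents a small step `epsilon` apart, scaled by `1/epsilon²`; paths in Z use spherical interpolation. A sketch of the latent sampling, under the assumption that the metric follows the StyleGAN formulation (epsilon = 1e-4 in the paper):

```python
# Spherical interpolation between latent batches a and b at parameter t,
# plus the endpoint pair for one PPL sample.
import torch

def slerp(a, b, t):
    a = a / a.norm(dim=-1, keepdim=True)
    b = b / b.norm(dim=-1, keepdim=True)
    d = (a * b).sum(dim=-1, keepdim=True)
    p = t * torch.acos(d)
    c = b - d * a
    c = c / c.norm(dim=-1, keepdim=True)
    d = a * torch.cos(p) + c * torch.sin(p)
    return d / d.norm(dim=-1, keepdim=True)

epsilon = 1e-4
z0, z1 = torch.randn([2, 512]), torch.randn([2, 512])
t = torch.rand([2, 1])
zt0 = slerp(z0, z1, t)              # first endpoint
zt1 = slerp(z0, z1, t + epsilon)    # second endpoint
# ppl sample = perceptual_distance(G(zt0), G(zt1)) / epsilon**2
```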
metrics/precision_recall.py
CHANGED
@@ -6,8 +6,12 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
-
+"""Precision/Recall (PR) from the paper "Improved Precision and Recall
+Metric for Assessing Generative Models". Matches the original implementation
+by Kynkaanniemi et al. at
+https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
 
+import torch
 from . import metric_utils
 
 #----------------------------------------------------------------------------
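The improved precision/recall metric tests whether samples fall inside a k-nearest-neighbor estimate of the other distribution's manifold in feature space. A brute-force sketch of that test (helper names hypothetical; a real implementation would batch this on the GPU):

```python
# k-NN manifold membership test behind precision/recall.
import torch

def knn_radii(features, k=3):
    dists = torch.cdist(features, features)        # pairwise distances
    return dists.kthvalue(k + 1, dim=1).values     # k-th neighbor, skipping self

def fraction_covered(manifold, radii, samples):
    d = torch.cdist(samples, manifold)             # [num_samples, num_manifold]
    return (d <= radii.unsqueeze(0)).any(dim=1).float().mean().item()

# precision = fraction_covered(real_feats, knn_radii(real_feats), gen_feats)
# recall    = fraction_covered(gen_feats,  knn_radii(gen_feats),  real_feats)
```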
torch_utils/persistence.py
CHANGED
@@ -6,6 +6,13 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+"""Facilities for pickling Python code alongside other data.
+
+The pickled code is automatically imported into a separate Python module
+during unpickling. This way, any previously exported pickles will remain
+usable even if the original code is no longer available, or if the current
+version of the code is not consistent with what was originally pickled."""
+
 import sys
 import pickle
 import io
@@ -17,29 +24,70 @@ import dnnlib
 
 #----------------------------------------------------------------------------
 
-_version
-_decorators
-_import_hooks
-_module_to_src_dict = dict()
-_src_to_module_dict = dict()
-
-#----------------------------------------------------------------------------
-
-def is_persistent(obj):
-    try:
-        if obj in _decorators:
-            return True
-    except TypeError:
-        pass
-    return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
-
-def import_hook(func):
-    assert callable(func)
-    _import_hooks.append(func)
+_version            = 6      # internal version number
+_decorators         = set()  # {decorator_class, ...}
+_import_hooks       = []     # [hook_function, ...]
+_module_to_src_dict = dict() # {module: src, ...}
+_src_to_module_dict = dict() # {src: module, ...}
 
 #----------------------------------------------------------------------------
 
 def persistent_class(orig_class):
+    r"""Class decorator that extends a given class to save its source code
+    when pickled.
+
+    Example:
+
+        from torch_utils import persistence
+
+        @persistence.persistent_class
+        class MyNetwork(torch.nn.Module):
+            def __init__(self, num_inputs, num_outputs):
+                super().__init__()
+                self.fc = MyLayer(num_inputs, num_outputs)
+                ...
+
+        @persistence.persistent_class
+        class MyLayer(torch.nn.Module):
+            ...
+
+    When pickled, any instance of `MyNetwork` and `MyLayer` will save its
+    source code alongside other internal state (e.g., parameters, buffers,
+    and submodules). This way, any previously exported pickle will remain
+    usable even if the class definitions have been modified or are no
+    longer available.
+
+    The decorator saves the source code of the entire Python module
+    containing the decorated class. It does *not* save the source code of
+    any imported modules. Thus, the imported modules must be available
+    during unpickling, also including `torch_utils.persistence` itself.
+
+    It is ok to call functions defined in the same module from the
+    decorated class. However, if the decorated class depends on other
+    classes defined in the same module, they must be decorated as well.
+    This is illustrated in the above example in the case of `MyLayer`.
+
+    It is also possible to employ the decorator just-in-time before
+    calling the constructor. For example:
+
+        cls = MyLayer
+        if want_to_make_it_persistent:
+            cls = persistence.persistent_class(cls)
+        layer = cls(num_inputs, num_outputs)
+
+    As an additional feature, the decorator also keeps track of the
+    arguments that were used to construct each instance of the decorated
+    class. The arguments can be queried via `obj.init_args` and
+    `obj.init_kwargs`, and they are automatically pickled alongside other
+    object state. A typical use case is to first unpickle a previous
+    instance of a persistent class, and then upgrade it to use the latest
+    version of the source code:
+
+        with open('old_pickle.pkl', 'rb') as f:
+            old_net = pickle.load(f)
+        new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
+        misc.copy_params_and_buffers(old_net, new_net, require_all=True)
+    """
     assert isinstance(orig_class, type)
     if is_persistent(orig_class):
         return orig_class
@@ -83,7 +131,55 @@ def persistent_class(orig_class):
 
 #----------------------------------------------------------------------------
 
+def is_persistent(obj):
+    r"""Test whether the given object or class is persistent, i.e.,
+    whether it will save its source code when pickled.
+    """
+    try:
+        if obj in _decorators:
+            return True
+    except TypeError:
+        pass
+    return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
+
+#----------------------------------------------------------------------------
+
+def import_hook(hook):
+    r"""Register an import hook that is called whenever a persistent object
+    is being unpickled. A typical use case is to patch the pickled source
+    code to avoid errors and inconsistencies when the API of some imported
+    module has changed.
+
+    The hook should have the following signature:
+
+        hook(meta) -> modified meta
+
+    `meta` is an instance of `dnnlib.EasyDict` with the following fields:
+
+        type:        Type of the persistent object, e.g. `'class'`.
+        version:     Internal version number of `torch_utils.persistence`.
+        module_src:  Original source code of the Python module.
+        class_name:  Class name in the original Python module.
+        state:       Internal state of the object.
+
+    Example:
+
+        @persistence.import_hook
+        def wreck_my_network(meta):
+            if meta.class_name == 'MyNetwork':
+                print('MyNetwork is being imported. I will wreck it!')
+                meta.module_src = meta.module_src.replace("True", "False")
+            return meta
+    """
+    assert callable(hook)
+    _import_hooks.append(hook)
+
+#----------------------------------------------------------------------------
+
 def _reconstruct_persistent_obj(meta):
+    r"""Hook that is called internally by the `pickle` module to unpickle
+    a persistent object.
+    """
     meta = dnnlib.EasyDict(meta)
     meta.state = dnnlib.EasyDict(meta.state)
     for hook in _import_hooks:
@@ -108,6 +204,8 @@ def _reconstruct_persistent_obj(meta):
 #----------------------------------------------------------------------------
 
 def _module_to_src(module):
+    r"""Query the source code of a given Python module.
+    """
     src = _module_to_src_dict.get(module, None)
     if src is None:
         src = inspect.getsource(module)
@@ -116,6 +214,8 @@ def _module_to_src(module):
     return src
 
 def _src_to_module(src):
+    r"""Get or create a Python module for the given source code.
+    """
    module = _src_to_module_dict.get(src, None)
    if module is None:
        module_name = "_imported_module_" + uuid.uuid4().hex
@@ -129,15 +229,19 @@ def _src_to_module(src):
 
 #----------------------------------------------------------------------------
 
 def _check_pickleable(obj):
+    r"""Check that the given object is pickleable, raising an exception if
+    it is not. This function is expected to be considerably more efficient
+    than actually pickling the object.
+    """
     def recurse(obj):
         if isinstance(obj, (list, tuple, set)):
             return [recurse(x) for x in obj]
         if isinstance(obj, dict):
             return [[recurse(x), recurse(y)] for x, y in obj.items()]
         if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
-            return None #
+            return None # Python primitive types are pickleable.
         if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
-            return None #
+            return None # NumPy arrays and PyTorch tensors are pickleable.
         if is_persistent(obj):
             return None # Persistent objects are pickleable, by virtue of the constructor check.
         return obj
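To make the behavior concrete, a usage sketch built from the docstrings above; `MyNetwork` is the hypothetical decorated class from the `persistent_class` example:

```python
# Round trip: the pickle embeds the module source, so it loads even where
# the original class definition has changed or is unavailable.
import pickle

net = MyNetwork(num_inputs=8, num_outputs=4)   # decorated with @persistent_class
data = pickle.dumps(net)                       # source code is embedded in the pickle
net2 = pickle.loads(data)                      # goes through _reconstruct_persistent_obj()
print(net2.init_args, net2.init_kwargs)        # constructor arguments are preserved
```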
torch_utils/training_stats.py
CHANGED
@@ -6,6 +6,11 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+"""Facilities for reporting and collecting training statistics across
+multiple processes and devices. The interface is designed to minimize
+synchronization overhead as well as the amount of boilerplate in user
+code."""
+
 import re
 import numpy as np
 import torch
@@ -15,19 +20,31 @@ from . import misc
 
 #----------------------------------------------------------------------------
 
-_num_moments    = 3              # [num_scalars,
+_num_moments    = 3              # [num_scalars, sum_of_scalars, sum_of_squares]
 _reduce_dtype   = torch.float32  # Data type to use for initial per-tensor reduction.
-_counter_dtype  = torch.float64  # Data type to use for the counters.
-
+_counter_dtype  = torch.float64  # Data type to use for the internal counters.
 _rank           = 0              # Rank of the current process.
 _sync_device    = None           # Device to use for multiprocess communication. None = single-process.
 _sync_called    = False          # Has _sync() been called yet?
-_counters       = dict()         # Running
-_cumulative     = dict()         # Cumulative
+_counters       = dict()         # Running counters on each device, updated by report(): name => device => torch.Tensor
+_cumulative     = dict()         # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
 
 #----------------------------------------------------------------------------
 
 def init_multiprocessing(rank, sync_device):
+    r"""Initializes `torch_utils.training_stats` for collecting statistics
+    across multiple processes.
+
+    This function must be called after
+    `torch.distributed.init_process_group()` and before `Collector.update()`.
+    The call is not necessary if multi-process collection is not needed.
+
+    Args:
+        rank:           Rank of the current process.
+        sync_device:    PyTorch device to use for inter-process
+                        communication, or None to disable multi-process
+                        collection. Typically `torch.device('cuda', rank)`.
+    """
     global _rank, _sync_device
     assert not _sync_called
     _rank = rank
@@ -37,6 +54,28 @@ def init_multiprocessing(rank, sync_device):
 
 @misc.profiled_function
 def report(name, value):
+    r"""Broadcasts the given set of scalars to all interested instances of
+    `Collector`, across device and process boundaries.
+
+    This function is expected to be extremely cheap and can be safely
+    called from anywhere in the training loop, loss function, or inside a
+    `torch.nn.Module`.
+
+    Warning: The current implementation expects the set of unique names to
+    be consistent across processes. Please make sure that `report()` is
+    called at least once for each unique name by each process, and in the
+    same order. If a given process has no scalars to broadcast, it can do
+    `report(name, [])` (empty list).
+
+    Args:
+        name:   Arbitrary string specifying the name of the statistic.
+                Averages are accumulated separately for each unique name.
+        value:  Arbitrary set of scalars. Can be a list, tuple,
+                NumPy array, PyTorch tensor, or Python scalar.
+
+    Returns:
+        The same `value` that was passed in.
+    """
     if name not in _counters:
         _counters[name] = dict()
 
@@ -45,7 +84,11 @@ def report(name, value):
         return value
 
     elems = elems.detach().flatten().to(_reduce_dtype)
-    moments = torch.stack([
+    moments = torch.stack([
+        torch.ones_like(elems).sum(),
+        elems.sum(),
+        elems.square().sum(),
+    ])
     assert moments.ndim == 1 and moments.shape[0] == _num_moments
     moments = moments.to(_counter_dtype)
 
@@ -58,45 +101,35 @@ def report(name, value):
 #----------------------------------------------------------------------------
 
 def report0(name, value):
+    r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
+    but ignores any scalars provided by the other processes.
+    See `report()` for further details.
+    """
     report(name, value if _rank == 0 else [])
     return value
 
 #----------------------------------------------------------------------------
 
-def _sync(names):
-    if len(names) == 0:
-        return []
-    global _sync_called
-    _sync_called = True
-
-    # Collect deltas within current rank.
-    deltas = []
-    device = _sync_device if _sync_device is not None else torch.device('cpu')
-    for name in names:
-        delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
-        for counter in _counters[name].values():
-            delta.add_(counter.to(device))
-            counter.copy_(torch.zeros_like(counter))
-        deltas.append(delta)
-    deltas = torch.stack(deltas)
-
-    # Sum deltas across ranks.
-    if _sync_device is not None:
-        torch.distributed.all_reduce(deltas)
-
-    # Update cumulative values.
-    deltas = deltas.cpu()
-    for idx, name in enumerate(names):
-        if name not in _cumulative:
-            _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
-        _cumulative[name].add_(deltas[idx])
-
-    # Return name-value pairs.
-    return [(name, _cumulative[name]) for name in names]
-
-#----------------------------------------------------------------------------
-
-class Collector:
+class Collector:
+    r"""Collects the scalars broadcasted by `report()` and `report0()` and
+    computes their long-term averages (mean and standard deviation) over
+    user-defined periods of time.
+
+    The averages are first collected into internal counters that are not
+    directly visible to the user. They are then copied to the user-visible
+    state as a result of calling `update()` and can then be queried using
+    `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
+    internal counters for the next round, so that the user-visible state
+    effectively reflects averages collected between the last two calls to
+    `update()`.
+
+    Args:
+        regex:          Regular expression defining which statistics to
+                        collect. The default is to collect everything.
+        keep_previous:  Whether to retain the previous averages if no
+                        scalars were collected on a given round
+                        (default: True).
+    """
     def __init__(self, regex='.*', keep_previous=True):
         self._regex = re.compile(regex)
         self._keep_previous = keep_previous
@@ -106,9 +139,24 @@ class Collector:
         self._moments.clear()
 
     def names(self):
+        r"""Returns the names of all statistics broadcasted so far that
+        match the regular expression specified at construction time.
+        """
         return [name for name in _counters if self._regex.fullmatch(name)]
 
     def update(self):
+        r"""Copies current values of the internal counters to the
+        user-visible state and resets them for the next round.
+
+        If `keep_previous=True` was specified at construction time, the
+        operation is skipped for statistics that have received no scalars
+        since the last update, retaining their previous averages.
+
+        This method performs a number of GPU-to-CPU transfers and one
+        `torch.distributed.all_reduce()`. It is intended to be called
+        periodically in the main training loop, typically once every
+        N training steps.
+        """
         if not self._keep_previous:
             self._moments.clear()
         for name, cumulative in _sync(self.names()):
@@ -120,22 +168,38 @@ class Collector:
             self._moments[name] = delta
 
     def _get_delta(self, name):
+        r"""Returns the raw moments that were accumulated for the given
+        statistic between the last two calls to `update()`, or zero if
+        no scalars were collected.
+        """
         assert self._regex.fullmatch(name)
         if name not in self._moments:
             self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
         return self._moments[name]
 
     def num(self, name):
+        r"""Returns the number of scalars that were accumulated for the given
+        statistic between the last two calls to `update()`, or zero if
+        no scalars were collected.
+        """
         delta = self._get_delta(name)
         return int(delta[0])
 
     def mean(self, name):
+        r"""Returns the mean of the scalars that were accumulated for the
+        given statistic between the last two calls to `update()`, or NaN if
+        no scalars were collected.
+        """
         delta = self._get_delta(name)
         if int(delta[0]) == 0:
             return float('nan')
         return float(delta[1] / delta[0])
 
     def std(self, name):
+        r"""Returns the standard deviation of the scalars that were
+        accumulated for the given statistic between the last two calls to
+        `update()`, or NaN if no scalars were collected.
+        """
         delta = self._get_delta(name)
         if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
             return float('nan')
@@ -146,12 +210,59 @@ class Collector:
         return np.sqrt(max(raw_var - np.square(mean), 0))
 
     def as_dict(self):
+        r"""Returns the averages accumulated between the last two calls to
+        `update()` as a `dnnlib.EasyDict`. The contents are as follows:
+
+            dnnlib.EasyDict(
+                NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
+                ...
+            )
+        """
         stats = dnnlib.EasyDict()
         for name in self.names():
             stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
         return stats
 
     def __getitem__(self, name):
+        r"""Convenience getter.
+        `collector[name]` is a synonym for `collector.mean(name)`.
+        """
         return self.mean(name)
 
 #----------------------------------------------------------------------------
+
+def _sync(names):
+    r"""Synchronize the global cumulative counters across devices and
+    processes. Called internally by `Collector.update()`.
+    """
+    if len(names) == 0:
+        return []
+    global _sync_called
+    _sync_called = True
+
+    # Collect deltas within current rank.
+    deltas = []
+    device = _sync_device if _sync_device is not None else torch.device('cpu')
+    for name in names:
+        delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
+        for counter in _counters[name].values():
+            delta.add_(counter.to(device))
+            counter.copy_(torch.zeros_like(counter))
+        deltas.append(delta)
+    deltas = torch.stack(deltas)
+
+    # Sum deltas across ranks.
+    if _sync_device is not None:
+        torch.distributed.all_reduce(deltas)
+
+    # Update cumulative values.
+    deltas = deltas.cpu()
+    for idx, name in enumerate(names):
+        if name not in _cumulative:
+            _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+        _cumulative[name].add_(deltas[idx])
+
+    # Return name-value pairs.
+    return [(name, _cumulative[name]) for name in names]
+
+#----------------------------------------------------------------------------
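Putting the pieces together, a usage sketch based on the docstrings above (the statistic name and loop structure are illustrative):

```python
# Report scalars anywhere in the training code; collect averages periodically.
import torch
from torch_utils import training_stats

collector = training_stats.Collector(regex='Loss/.*')

for step in range(1000):
    loss = torch.rand([])                        # stand-in for a real loss value
    training_stats.report('Loss/G/total', loss)  # cheap, safe inside the loop

    if step % 100 == 99:
        collector.update()                       # sync counters, reset for next round
        print(step, collector.as_dict())         # EasyDict of num/mean/std per name
```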