# -*- coding: utf-8 -*-
# File : test_sync_batchnorm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
import unittest

import torch
import torch.nn as nn
from torch.autograd import Variable

from sync_batchnorm import (DataParallelWithCallback, SynchronizedBatchNorm1d,
                            SynchronizedBatchNorm2d)
from sync_batchnorm.unittest import TorchTestCase
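
# Note: torch.autograd.Variable has been a thin alias for Tensor since PyTorch 0.4;
# on a modern PyTorch the Variable(...) wrappers below could be replaced with
# tensors created via requires_grad_(True). They are kept here as in the original code.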


def handy_var(a, unbias=True):
    """Reference per-feature variance over the batch dimension (dim=0)."""
    n = a.size(0)
    asum = a.sum(dim=0)
    as_sum = (a ** 2).sum(dim=0)  # sum of squares
    sumvar = as_sum - asum * asum / n
    if unbias:
        return sumvar / (n - 1)
    else:
        return sumvar / n
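
# Illustrative sanity check (not part of the original test suite): handy_var
# should agree with torch.var along the batch dimension, e.g.
#
#     x = torch.rand(16, 10)
#     assert torch.allclose(handy_var(x), x.var(dim=0, unbiased=True), atol=1e-6)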


def _find_bn(module):
    """Return the first (vanilla or synchronized) batch-norm layer inside `module`."""
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
            return m


class SyncTestCase(TorchTestCase):
    def _syncParameters(self, bn1, bn2):
        bn1.reset_parameters()
        bn2.reset_parameters()
        if bn1.affine and bn2.affine:
            bn2.weight.data.copy_(bn1.weight.data)
            bn2.bias.data.copy_(bn1.bias.data)

    def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
        """Check the forward and backward for the customized batch normalization."""
        bn1.train(mode=is_train)
        bn2.train(mode=is_train)
        if cuda:
            input = input.cuda()

        self._syncParameters(_find_bn(bn1), _find_bn(bn2))

        input1 = Variable(input, requires_grad=True)
        output1 = bn1(input1)
        output1.sum().backward()
        input2 = Variable(input, requires_grad=True)
        output2 = bn2(input2)
        output2.sum().backward()

        self.assertTensorClose(input1.data, input2.data)
        self.assertTensorClose(output1.data, output2.data)
        self.assertTensorClose(input1.grad, input2.grad)
        self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
        self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)

    def testSyncBatchNormNormalTrain(self):
        bn = nn.BatchNorm1d(10)
        sync_bn = SynchronizedBatchNorm1d(10)
        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)

    def testSyncBatchNormNormalEval(self):
        bn = nn.BatchNorm1d(10)
        sync_bn = SynchronizedBatchNorm1d(10)
        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
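
    # The *Sync* tests below wrap SynchronizedBatchNorm in DataParallelWithCallback
    # with device_ids=[0, 1], so they assume at least two visible CUDA devices.
    # On a single-GPU machine they could be guarded with something like
    # @unittest.skipIf(torch.cuda.device_count() < 2, 'requires >= 2 GPUs')
    # (a suggested guard, not part of the original file).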

    def testSyncBatchNormSyncTrain(self):
        bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)

    def testSyncBatchNormSyncEval(self):
        bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)

    def testSyncBatchNorm2DSyncTrain(self):
        bn = nn.BatchNorm2d(10)
        sync_bn = SynchronizedBatchNorm2d(10)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)


if __name__ == '__main__':
    unittest.main()