# Spaces: Runtime error (page status text captured along with this file; not part of the program)
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
import copy | |
import logging | |
import unittest | |
import torch | |
from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer | |
from omegaconf import OmegaConf | |
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestGradientScaling(unittest.TestCase):
    """Check one training step of the FP16 optimizers on a tiny linear model.

    A single-input, single-output ``torch.nn.Linear`` is initialised with
    known weight and bias, run for one iteration against a target that is
    off by a fixed error, and the resulting loss, gradient norm, updated
    parameters and loss scale are compared with hard-coded expected values.

    The fixture places tensors and the model on CUDA, so the whole class is
    skipped when no GPU is available rather than erroring in ``setUp``.
    """

    def setUp(self) -> None:
        """Build the input, target, model and optimizer config for each test."""
        self.x = torch.tensor([2.0]).cuda().half()
        weight = 3.0
        bias = 5.0
        # The target is deliberately offset from the model's initial output
        # (x * weight + bias) by ``self.error``, so the initial L1 loss is 1.0.
        self.error = 1.0
        self.target = torch.tensor([self.x * weight + bias + self.error]).cuda().half()
        self.loss_fn = torch.nn.L1Loss()

        self.model = torch.nn.Linear(1, 1)
        self.model.weight.data = torch.tensor([[weight]])
        self.model.bias.data = torch.tensor([bias])
        # The return value is discarded: the module is moved to the GPU and
        # converted to half precision in place.
        self.model.cuda().half()
        self.params = list(self.model.parameters())

        # Config consumed by ``build_optimizer`` in the tests below.
        self.cfg_dls = OmegaConf.create(
            {
                "optimization": {
                    "lr": [0.1],
                },
                "optimizer": {
                    "_name": "adam",
                    "lr": [0.1],
                    "adam_betas": "(0.9, 0.999)",
                    "adam_eps": 1e-8,
                    "weight_decay": 0.0,
                },
                "common": {
                    "fp16_init_scale": 1,
                    "fp16_scale_window": 1,
                    "fp16_scale_tolerance": 1,
                    "threshold_loss_scale": 1,
                    "min_loss_scale": 1e-4,
                    "tpu": False,
                },
            }
        )
        # Silence all log output while a test runs; restored in tearDown.
        logging.disable(logging.CRITICAL)

    def tearDown(self) -> None:
        """Re-enable logging that ``setUp`` disabled."""
        logging.disable(logging.NOTSET)

    def run_iter(self, model, params, optimizer) -> None:
        """Run one forward/backward/step cycle and assert on the results.

        Args:
            model: The half-precision linear model to train for one step.
            params: The model's parameter list. Not used in this method;
                kept so existing callers can keep passing it.
            optimizer: An optimizer exposing ``zero_grad``, ``backward``,
                ``clip_grad_norm``, ``step`` and ``scaler.loss_scale``.
        """
        optimizer.zero_grad()
        y = model(self.x)
        loss = self.loss_fn(y, self.target)
        optimizer.backward(loss)
        self.assertEqual(loss, torch.tensor(1.0, device="cuda:0", dtype=torch.float16))

        # Expected norm 2.2361 is sqrt(5), checked to four decimal places.
        grad_norm = optimizer.clip_grad_norm(0)
        self.assertAlmostEqual(grad_norm.item(), 2.2361, 4)

        optimizer.step()
        # Expected values are the float16 results after a single step.
        self.assertEqual(
            model.weight,
            torch.tensor(
                [[3.0996]], device="cuda:0", dtype=torch.float16, requires_grad=True
            ),
        )
        self.assertEqual(
            model.bias,
            torch.tensor(
                [5.1016], device="cuda:0", dtype=torch.float16, requires_grad=True
            ),
        )
        # The configured initial loss scale is 1; after this step it must be 2.0.
        self.assertEqual(optimizer.scaler.loss_scale, 2.0)

    def test_mixed_precision(self) -> None:
        """Run one step with ``FP16Optimizer`` and check its fp32 parameters."""
        model = copy.deepcopy(self.model)
        params = list(model.parameters())
        optimizer = FP16Optimizer.build_optimizer(self.cfg_dls, params)

        self.run_iter(model, params, optimizer)
        # Every entry of the optimizer's fp32 parameter collection must equal
        # the expected updated [weight, bias] pair.
        self.assertTrue(
            all(
                torch.all(
                    fp32_params.eq(
                        torch.tensor(
                            [3.1000, 5.1000], device="cuda:0", requires_grad=True
                        )
                    )
                )
                for fp32_params in optimizer.fp32_params.values()
            )
        )

    def test_memory_efficient(self) -> None:
        """Run one step with ``MemoryEfficientFP16Optimizer``."""
        model = copy.deepcopy(self.model)
        params = list(model.parameters())
        optimizer = MemoryEfficientFP16Optimizer.build_optimizer(self.cfg_dls, params)

        self.run_iter(model, params, optimizer)
# Run the test cases in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()