"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`. |
|
|
|
Please refer to https://github.com/NVlabs/stylegan3 |
|
""" |
|
|
|
|
|
|
|
|
|
import torch |


def fma(a, b, c, impl='cuda'):
    """Compute `a * b + c` (with broadcasting); `impl='cuda'` routes through
    the custom autograd function below, any other value falls back to plain
    `torch.addcmul()`."""
    if impl == 'cuda':
        return _FusedMultiplyAdd.apply(a, b, c)
    return torch.addcmul(c, a, b)
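
# Usage sketch (not part of the original file; shapes chosen for illustration):
#
#   a = torch.randn(4, 8, requires_grad=True)
#   b = torch.randn(4, 8, requires_grad=True)
#   c = torch.randn(8, requires_grad=True)  # broadcasts against a * b
#   y = fma(a, b, c)                        # same value as a * b + c
#   y.sum().backward()                      # gradients reach a, b, and c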


class _FusedMultiplyAdd(torch.autograd.Function): # out = a * b + c
    @staticmethod
    def forward(ctx, a, b, c):
        out = torch.addcmul(c, a, b)
        ctx.save_for_backward(a, b) # c itself is not needed in backward,
        ctx.c_shape = c.shape       # only its shape.
        return out

    @staticmethod
    def backward(ctx, dout):
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        # For out = a * b + c: d(out)/da = b, d(out)/db = a, d(out)/dc = 1.
        # Each gradient is reduced back to its input's shape in case that
        # input was broadcast in the forward pass.
        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc
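
# Shape bookkeeping (hypothetical shapes): if a and b have shape (4, 8) and
# c has shape (8,), the forward pass broadcasts c, so dout has shape (4, 8)
# and dc must be summed over dim 0 to recover shape (8,). That reduction is
# what `_unbroadcast()` below performs.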


def _unbroadcast(x, shape):
    """Sum `x` over its broadcast dimensions so that it matches `shape`."""
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    # A dimension was broadcast if it was prepended by broadcasting
    # (i < extra_dims) or had size 1 in the target shape.
    dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
    if len(dim):
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        # Collapse the prepended singleton dimensions.
        x = x.reshape(-1, *x.shape[extra_dims+1:])
    assert x.shape == shape
    return x
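

# A minimal self-check (not part of the original module), assuming the file
# is run directly: compares the custom op against `torch.addcmul()` and
# verifies its backward pass numerically with `torch.autograd.gradcheck`,
# which requires double precision.
if __name__ == '__main__':
    torch.manual_seed(0)
    a = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
    b = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
    c = torch.randn(8, dtype=torch.float64, requires_grad=True) # broadcast against a * b

    # The forward value must match the reference exactly (same underlying op).
    assert torch.equal(fma(a, b, c), torch.addcmul(c, a, b))

    # The custom backward must agree with numerical differentiation.
    torch.autograd.gradcheck(fma, (a, b, c))
    print('fma: forward and backward checks passed')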