JeffreyXiang's picture
Upload
db6a3b7
raw
history blame
1.16 kB
import torch.nn as nn
from ..modules import sparse as sp
# Layer types whose parameters convert_module_to_f16 / convert_module_to_f32
# cast. Any module that is an instance of one of these types is converted.
FP16_MODULES = (
    # dense torch convolutions and transposed convolutions (1D / 2D / 3D)
    nn.Conv1d, nn.Conv2d, nn.Conv3d,
    nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
    # dense fully-connected layer
    nn.Linear,
    # sparse layer types from ..modules.sparse
    sp.SparseConv3d, sp.SparseInverseConv3d, sp.SparseLinear,
)
def convert_module_to_f16(l):
    """
    Cast the parameters of a primitive layer to float16.

    Does nothing unless `l` is an instance of one of the FP16_MODULES types.
    For a matching module, every tensor yielded by `l.parameters()` (which
    recurses into child modules) has its `.data` replaced by a float16
    version. The Parameter objects themselves are kept.

    NOTE(review): presumably applied per-submodule via `model.apply(...)`;
    confirm with callers.
    """
    if not isinstance(l, FP16_MODULES):
        return
    for param in l.parameters():
        param.data = param.data.half()
def convert_module_to_f32(l):
    """
    Cast the parameters of a primitive layer back to float32.

    This is the inverse of convert_module_to_f16(). It does nothing unless `l`
    is an instance of one of the FP16_MODULES types. For a matching module,
    every tensor yielded by `l.parameters()` (recursing into children) has its
    `.data` replaced by a float32 version.
    """
    if not isinstance(l, FP16_MODULES):
        return
    for param in l.parameters():
        param.data = param.data.float()
def zero_module(module):
    """
    Fill every parameter of `module` with zeros in place.

    Returns the same module object, so the call can wrap a layer inline where
    it is constructed.
    """
    for weight in module.parameters():
        # detach() gives a view sharing storage with the parameter, so the
        # in-place write updates the parameter outside autograd tracking.
        weight.detach().zero_()
    return module
def scale_module(module, scale):
    """
    Multiply every parameter of `module` by `scale` in place.

    Returns the same module object, so the call can be used inline.
    """
    for weight in module.parameters():
        # Write through a detached view that shares the parameter's storage,
        # so the multiply is not recorded by autograd.
        weight.detach().mul_(scale)
    return module
def modulate(x, shift, scale):
    """
    Apply an affine modulation: ``x * (1 + scale) + shift``.

    A singleton dimension is inserted at position 1 of both `shift` and
    `scale` before broadcasting, so one shift/scale pair is shared across
    that axis of `x`.

    NOTE(review): presumably x is (batch, tokens, channels) and shift/scale
    are (batch, channels); confirm against callers.
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b