|
from diffusers.models.normalization import RMSNorm |
|
import torch |
|
from torch import nn, Tensor |
|
|
|
|
|
|
|
|
|
|
|
class Linear8(nn.Module):
    """Drop-in replacement for ``nn.Linear`` that performs the matmul in FP8
    (``float8_e4m3fn``) via ``torch._scaled_mm`` on GPUs that support FP8
    compute (Ada / SM 8.9 and Hopper / SM 9.0+), and falls back to a regular
    ``F.linear`` everywhere else.

    The public interface matches ``nn.Linear``: ``weight`` of shape
    ``(out_features, in_features)`` and an optional ``bias`` of shape
    ``(out_features,)``.
    """

    __constants__ = ['in_features', 'out_features']

    def __init__(self,
                 in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        # Decide once at construction whether the FP8 fast path is usable.
        self.does_fp8 = self.supports_fp8_compute()
        # Kept as plain tensors (NOT registered buffers) on purpose: buffers
        # would appear in the expected state-dict keys and make
        # ``load_state_dict(child.state_dict())`` from a vanilla nn.Linear
        # fail with "missing keys".
        self.scale_weight = torch.ones(1, device=device, dtype=torch.float32)
        self.scale_input = torch.ones(1, device=device, dtype=torch.float32)

    def supports_fp8_compute(self, device=None):
        """Return ``True`` when the (current or given) CUDA device supports
        FP8 matmul, i.e. compute capability >= 8.9 (Ada) or >= 9.0 (Hopper).

        Returns ``False`` on machines without CUDA instead of raising.
        """
        # BUGFIX: get_device_properties raises RuntimeError on CPU-only
        # machines, which made __init__ itself crash. Report "no FP8 support"
        # instead so the layer transparently falls back to F.linear.
        if not torch.cuda.is_available():
            return False
        props = torch.cuda.get_device_properties(device)
        return props.major >= 9 or (props.major == 8 and props.minor >= 9)

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``x @ weight.T + bias``, using FP8 compute when available.

        The FP8 path returns bfloat16; the fallback path returns whatever
        ``F.linear`` produces for the stored parameter dtype.
        """
        if not self.does_fp8:
            return torch.nn.functional.linear(x, self.weight, self.bias)

        # _scaled_mm needs 2-D operands: flatten leading dims, restore below.
        dims = x.shape[:-1]
        x = x.view(-1, self.in_features)

        y = torch._scaled_mm(x.to(torch.float8_e4m3fn),
                             torch.transpose(self.weight, 0, 1),
                             scale_a=self.scale_input.to(device=x.device),
                             scale_b=self.scale_weight.to(device=x.device),
                             # BUGFIX: the original unconditionally called
                             # self.bias.to(...), crashing when bias=False.
                             bias=self.bias.to(torch.bfloat16) if self.bias is not None else None,
                             out_dtype=self.weight.dtype,
                             use_fast_accum=True)
        # BUGFIX: older PyTorch returned an (out, amax) tuple from _scaled_mm;
        # newer releases return the tensor directly, where the original `[0]`
        # would have silently kept only the first output row.
        if isinstance(y, tuple):
            y = y[0]

        return y.view(*dims, self.out_features).to(torch.bfloat16)
|
|
|
|
|
|
|
|
|
|
|
def replace_regular_linears(module, parent=''):
    """Recursively rewrite *module* in place for FP8 inference.

    Every ``torch.nn.Linear`` child is replaced by a ``Linear8`` with the
    same shape/device/dtype and the original weights copied over; every
    ``RMSNorm`` child has its weight cast to bfloat16; any other child module
    is recursed into.

    Args:
        module: the root ``nn.Module`` to rewrite (mutated in place).
        parent: dotted path of *module* within the original root; used only
            to build the path passed down through the recursion.
    """
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            device = child.weight.data.device
            dtype = child.weight.data.dtype
            # Idiom fix: `True if x is not None else False` -> `x is not None`.
            new_layer = Linear8(child.in_features,
                                child.out_features,
                                child.bias is not None,
                                device,
                                dtype)
            new_layer.load_state_dict(child.state_dict())
            new_layer = new_layer.to(device)

            setattr(module, name, new_layer)
        elif isinstance(child, RMSNorm):
            # Re-assign the RMSNorm weight as bfloat16 (assign=True swaps the
            # tensor rather than copying into the existing one).
            rsd = child.state_dict()
            if 'weight' in rsd:
                child.load_state_dict({'weight': rsd['weight'].to(torch.bfloat16)},
                                      assign=True)
        else:
            # Recurse, extending the dotted path ('' at the root so the
            # top-level children don't get a leading dot).
            child_path = name if parent == '' else '.'.join([parent, name])
            replace_regular_linears(child, parent=child_path)
|
|