from diffusers.models.normalization import RMSNorm
import torch
from torch import nn, Tensor
###
# Code from aredden/flux-fp8-api
class Linear8(nn.Module):
    __constants__ = ['in_features', 'out_features']

    def __init__(self,
                 in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.does_fp8 = self.supports_fp8_compute()
        # Per-tensor scales for _scaled_mm. These are plain tensors, not
        # buffers, so forward() moves them to the input's device on each call.
        self.scale_weight = torch.ones(1, device=device, dtype=torch.float32)
        self.scale_input = torch.ones(1, device=device, dtype=torch.float32)

    def supports_fp8_compute(self, device=None):
        # FP8 matmul needs CUDA with compute capability >= 8.9
        # (Ada Lovelace) or >= 9.0 (Hopper).
        if not torch.cuda.is_available():
            return False
        props = torch.cuda.get_device_properties(device)
        return props.major >= 9 or (props.major == 8 and props.minor >= 9)

    def forward(self, x: Tensor) -> Tensor:
        if not self.does_fp8:
            return torch.nn.functional.linear(x, self.weight, self.bias)
        # Flatten leading dims so _scaled_mm sees a 2D matmul.
        dims = x.shape[:-1]
        x = x.view(-1, self.in_features)
        # Requires torch 2.4, where _scaled_mm returns an (output, amax)
        # tuple; later releases return the output tensor directly, in which
        # case the trailing [0] must be dropped.
        y = torch._scaled_mm(x.to(torch.float8_e4m3fn),
                             torch.transpose(self.weight, 0, 1),
                             scale_a=self.scale_input.to(device=x.device),
                             scale_b=self.scale_weight.to(device=x.device),
                             bias=self.bias.to(torch.bfloat16) if self.bias is not None else None,
                             out_dtype=self.weight.dtype,
                             use_fast_accum=True)[0]
        return y.view(*dims, self.out_features).to(torch.bfloat16)
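
###
# Usage sketch (not part of the original file): a standalone smoke test for
# Linear8, assuming a CUDA device with FP8 support and torch 2.4. The weight
# is stored in float8_e4m3fn, matching how replace_regular_linears below
# hands the layer an already-fp8 state dict.
def _demo_linear8():
    layer = Linear8(64, 128, bias=True, device='cuda', dtype=torch.float8_e4m3fn)
    with torch.no_grad():
        # copy_ casts the float32 values down to float8_e4m3fn.
        layer.weight.copy_(torch.randn(128, 64, device='cuda') * 0.02)
        layer.bias.zero_()
    x = torch.randn(2, 10, 64, device='cuda', dtype=torch.bfloat16)
    out = layer(x)
    assert out.shape == (2, 10, 128)
    return out
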
###
# Code from sayakpaul
# http://github.com/huggingface/diffusers/issues/6500
def replace_regular_linears(module, parent=''):
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            in_features = child.in_features
            out_features = child.out_features
            device = child.weight.data.device
            dtype = child.weight.data.dtype
            has_bias = child.bias is not None
            new_layer = Linear8(in_features,
                                out_features,
                                has_bias,
                                device,
                                dtype)
            # Copy the original weights (and bias, if any) into the
            # replacement, then swap it into the parent module.
            new_layer.load_state_dict(child.state_dict())
            new_layer = new_layer.to(device)
            setattr(module, name, new_layer)
        elif isinstance(child, RMSNorm):
            # RMSNorm doesn't support float8.
            rsd = child.state_dict()
            if 'weight' in rsd:
                child.load_state_dict({'weight': rsd['weight'].to(torch.bfloat16)},
                                      assign=True)
        else:
            # Recursively apply to child modules.
            if parent == '':
                replace_regular_linears(child, parent=name)
            else:
                replace_regular_linears(child, parent='.'.join([parent, name]))
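
###
# Usage sketch (not part of the original file): apply the replacement to a
# Flux transformer whose checkpoint is stored in float8_e4m3fn. The file
# name is a hypothetical local path, and fp8 torch_dtype support in
# from_single_file is an assumption about the diffusers version in use.
if __name__ == '__main__':
    from diffusers import FluxTransformer2DModel

    transformer = FluxTransformer2DModel.from_single_file(
        'flux1-dev-fp8.safetensors',  # hypothetical local checkpoint
        torch_dtype=torch.float8_e4m3fn)
    replace_regular_linears(transformer)
    transformer.to('cuda')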