Create layer replacement.
linear_8.py  +87 -0
ADDED
@@ -0,0 +1,87 @@
from diffusers.models.normalization import RMSNorm
import torch
from torch import nn, Tensor


###
# Code from aredden/flux-fp8-api

class Linear8(nn.Module):
    __constants__ = ['in_features', 'out_features']

    def __init__(self,
                 in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.does_fp8 = self.supports_fp8_compute()
        self.scale_weight = torch.ones(1, device=device, dtype=torch.float32)
        self.scale_input = torch.ones(1, device=device, dtype=torch.float32)

    def supports_fp8_compute(self, device=None):
        # Float8 matmul needs an Ada (SM 8.9) or Hopper (SM 9.0) class GPU.
        if not torch.cuda.is_available():
            return False
        props = torch.cuda.get_device_properties(device)
        if props.major >= 9 or props.major == 8 and props.minor >= 9:
            return True

        return False

    # def __setattr__(self, key, value):
    #     if isinstance(value, nn.Parameter):
    #         pass

    def forward(self, x: Tensor) -> Tensor:
        if self.does_fp8 is False:
            return torch.nn.functional.linear(x, self.weight, self.bias)
        dims = x.shape[:-1]
        x = x.view(-1, self.in_features)
        # Requires torch 2.4. The transpose gives _scaled_mm the column-major
        # second operand it expects.
        y = torch._scaled_mm(x.to(torch.float8_e4m3fn),
                             torch.transpose(self.weight, 0, 1),
                             scale_a=self.scale_input.to(device=x.device),
                             scale_b=self.scale_weight.to(device=x.device),
                             bias=self.bias.to(torch.bfloat16) if self.bias is not None else None,
                             out_dtype=self.weight.dtype,
                             use_fast_accum=True)[0]

        return y.view(*dims, self.out_features).to(torch.bfloat16)

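# Note: torch._scaled_mm computes with float8 operands on both sides, so this
# layer assumes the weight it is given is already stored as float8_e4m3fn
# (e.g. a transformer checkpoint loaded with torch_dtype=torch.float8_e4m3fn).
# scale_input and scale_weight stay at 1.0, i.e. no dynamic amax-based
# rescaling is performed before the matmul.
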
###
# Code from sayakpaul
# http://github.com/huggingface/diffusers/issues/6500

def replace_regular_linears(module, parent=''):
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            in_features = child.in_features
            out_features = child.out_features
            device = child.weight.data.device
            dtype = child.weight.data.dtype
            has_bias = True if child.bias is not None else False
            new_layer = Linear8(in_features,
                                out_features,
                                has_bias,
                                device,
                                dtype)
            new_layer.load_state_dict(child.state_dict())
            new_layer = new_layer.to(device)

            setattr(module, name, new_layer)
        elif isinstance(child, RMSNorm):
            # RMSNorm doesn't support float8.
            rsd = child.state_dict()
            if 'weight' in rsd:
                child.load_state_dict({'weight': rsd['weight'].to(torch.bfloat16)},
                                      assign=True)
        else:
            # Recursively apply to child modules.
            if parent == '':
                replace_regular_linears(child, parent=name)
            else:
                replace_regular_linears(child, parent='.'.join([parent, name]))
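A minimal usage sketch (not part of this commit; the model id, subfolder, and dtype choices are assumptions): load the Flux transformer with float8 weights, swap its Linear layers with replace_regular_linears, then run it inside a bfloat16 pipeline.

# Usage sketch -- assumptions: FLUX.1-dev checkpoint, transformer weights
# loaded as float8_e4m3fn so the Linear8 _scaled_mm path sees fp8 weights.
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from linear_8 import replace_regular_linears

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer",
    torch_dtype=torch.float8_e4m3fn)
replace_regular_linears(transformer)  # nn.Linear -> Linear8, RMSNorm weights -> bfloat16

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer,
    torch_dtype=torch.bfloat16).to("cuda")
image = pipe("a photo of a cat").images[0]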