twodgirl committed
Commit
6171612
1 Parent(s): b05b8f1

Create layer replacement.

Files changed (1)
  1. linear_8.py +87 -0
linear_8.py ADDED
@@ -0,0 +1,87 @@
from diffusers.models.normalization import RMSNorm
import torch
from torch import nn, Tensor


###
# Code from aredden/flux-fp8-api

class Linear8(nn.Module):
    __constants__ = ['in_features', 'out_features']

    def __init__(self,
                 in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.does_fp8 = self.supports_fp8_compute()
        # Per-tensor scales for _scaled_mm; moved to the input's device in forward().
        self.scale_weight = torch.ones(1, device=device, dtype=torch.float32)
        self.scale_input = torch.ones(1, device=device, dtype=torch.float32)

    def supports_fp8_compute(self, device=None):
        # FP8 matmul needs Ada (SM 8.9) or Hopper (SM 9.0) and newer.
        if not torch.cuda.is_available():
            return False
        props = torch.cuda.get_device_properties(device)

        return props.major >= 9 or (props.major == 8 and props.minor >= 9)

    # def __setattr__(self, key, value):
    #     if isinstance(value, nn.Parameter):
    #         pass

    def forward(self, x: Tensor) -> Tensor:
        if not self.does_fp8:
            return torch.nn.functional.linear(x, self.weight, self.bias)
        # Flatten leading dims; _scaled_mm expects 2D operands.
        dims = x.shape[:-1]
        x = x.view(-1, self.in_features)
        # Requires torch 2.4, where _scaled_mm returns the output tensor directly.
        y = torch._scaled_mm(x.to(torch.float8_e4m3fn),
                             torch.transpose(self.weight, 0, 1),
                             scale_a=self.scale_input.to(device=x.device),
                             scale_b=self.scale_weight.to(device=x.device),
                             bias=self.bias.to(torch.bfloat16) if self.bias is not None else None,
                             out_dtype=self.weight.dtype,
                             use_fast_accum=True)

        return y.view(*dims, self.out_features).to(torch.bfloat16)

###
# Code from sayakpaul
# http://github.com/huggingface/diffusers/issues/6500

def replace_regular_linears(module, parent=''):
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            in_features = child.in_features
            out_features = child.out_features
            device = child.weight.data.device
            dtype = child.weight.data.dtype
            has_bias = child.bias is not None
            new_layer = Linear8(in_features,
                                out_features,
                                has_bias,
                                device,
                                dtype)
            # Reuse the original weights and bias as-is.
            new_layer.load_state_dict(child.state_dict())
            new_layer = new_layer.to(device)

            setattr(module, name, new_layer)
        elif isinstance(child, RMSNorm):
            # RMSNorm doesn't support float8.
            rsd = child.state_dict()
            if 'weight' in rsd:
                child.load_state_dict({'weight': rsd['weight'].to(torch.bfloat16)},
                                      assign=True)
        else:
            # Recursively apply to child modules.
            if parent == '':
                replace_regular_linears(child, parent=name)
            else:
                replace_regular_linears(child, parent='.'.join([parent, name]))
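For context, a minimal usage sketch (not part of this commit; the pipeline class exists in diffusers, but the checkpoint id is a placeholder, and the transformer weights are assumed to already be stored in float8_e4m3fn so both operands of torch._scaled_mm are fp8):

import torch
from diffusers import FluxPipeline
from linear_8 import replace_regular_linears

# Assumption: 'some/flux-fp8-checkpoint' stands in for a real fp8 checkpoint.
pipe = FluxPipeline.from_pretrained('some/flux-fp8-checkpoint')
# Swap every nn.Linear in the transformer for Linear8; on Ada/Hopper GPUs
# the matmuls then run through torch._scaled_mm, otherwise they fall back
# to the regular F.linear path.
replace_regular_linears(pipe.transformer)
pipe.to('cuda')
image = pipe('a photo of a cat').images[0]

On GPUs without fp8 support, Linear8 behaves like a plain nn.Linear, so the replacement is safe to apply unconditionally.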