from diffusers.models.normalization import RMSNorm
import torch
from torch import nn, Tensor


###
# Code from aredden/flux-fp8-api

class Linear8(nn.Module):
    __constants__ = ['in_features', 'out_features']

    def __init__(self,
                 in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.does_fp8 = self.supports_fp8_compute()
        # Per-tensor scales consumed by torch._scaled_mm; they stay at 1.0 here,
        # so the float8 values are used unscaled unless set externally.
        self.scale_weight = torch.ones(1, device=device, dtype=torch.float32)
        self.scale_input = torch.ones(1, device=device, dtype=torch.float32)

    def supports_fp8_compute(self, device=None):
        # FP8 matmul needs CUDA and an Ada (SM 8.9) or Hopper (SM 9.0+) class GPU.
        if not torch.cuda.is_available():
            return False
        props = torch.cuda.get_device_properties(device)
        return props.major >= 9 or (props.major == 8 and props.minor >= 9)

    # def __setattr__(self, key, value):
    #     if isinstance(value, nn.Parameter):
    #         pass

    def forward(self, x: Tensor) -> Tensor:
        if not self.does_fp8:
            return torch.nn.functional.linear(x, self.weight, self.bias)
        # _scaled_mm only accepts 2-D operands, so flatten the leading dims.
        dims = x.shape[:-1]
        x = x.view(-1, self.in_features)
        # Requires torch 2.4.
        y = torch._scaled_mm(x.to(torch.float8_e4m3fn),
                             torch.transpose(self.weight, 0, 1),
                             scale_a=self.scale_input.to(device=x.device),
                             scale_b=self.scale_weight.to(device=x.device),
                             bias=self.bias.to(torch.bfloat16) if self.bias is not None else None,
                             out_dtype=self.weight.dtype,
                             use_fast_accum=True)
        # Some torch builds return (output, amax); newer ones return only the output tensor.
        if isinstance(y, tuple):
            y = y[0]

        return y.view(*dims, self.out_features).to(torch.bfloat16)

###
# Code from sayakpaul
# http://github.com/huggingface/diffusers/issues/6500

def replace_regular_linears(module, parent=''):
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            in_features = child.in_features
            out_features = child.out_features
            device = child.weight.data.device
            dtype = child.weight.data.dtype
            has_bias = child.bias is not None
            new_layer = Linear8(in_features,
                                out_features,
                                has_bias,
                                device,
                                dtype)
            new_layer.load_state_dict(child.state_dict())
            new_layer = new_layer.to(device)

            setattr(module, name, new_layer)
        elif isinstance(child, RMSNorm):
            # RMSNorm doesn't support float8.
            rsd = child.state_dict()
            if 'weight' in rsd:
                child.load_state_dict({'weight': rsd['weight'].to(torch.bfloat16)},
                                      assign=True)
        else:
            # Recursively apply to child modules.
            if parent == '':
                replace_regular_linears(child, parent=name)
            else:
                replace_regular_linears(child, parent='.'.join([parent, name]))
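

###
# Usage sketch (not part of the original file). A toy module stands in for the
# real target, which upstream is a Flux transformer whose Linear weights have
# already been quantized to torch.float8_e4m3fn with matching scale_input /
# scale_weight values; only the module swap itself is demonstrated here.

if __name__ == '__main__':
    model = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64))
    replace_regular_linears(model)
    converted = sum(isinstance(m, Linear8) for m in model.modules())
    print(f'Replaced {converted} nn.Linear layers with Linear8')  # expect 2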