import math
from typing import Any, Literal, Optional

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from cublas_ops_ext import _simt_hgemv
from cublas_ops_ext import cublas_hgemm_axbT as _cublas_hgemm_axbT
from cublas_ops_ext import cublas_hgemm_batched_simple as _cublas_hgemm_batched_simple
from cublas_ops_ext import (
    cublaslt_hgemm_batched_simple as _cublaslt_hgemm_batched_simple,
)
from cublas_ops_ext import cublaslt_hgemm_simple as _cublaslt_hgemm_simple

# Tracks, per CUDA device index, whether the static workspace/bias tensors
# have already been moved onto that device.
has_moved = {idx: False for idx in range(torch.cuda.device_count())}


class StaticState:
    """Per-device cuBLASLt workspace and empty half-precision bias tensors.

    Tensors are allocated on the host at import time and lazily moved to the
    requested CUDA device the first time `get()` is called for that device.
    """

    workspace = {
        idx: torch.empty((1024 * 1024 * 8,), dtype=torch.uint8)
        for idx in range(torch.cuda.device_count())
    }
    workspace_size = workspace[0].nelement()
    bias_g = {
        idx: torch.tensor([], dtype=torch.float16)
        for idx in range(torch.cuda.device_count())
    }

    @classmethod
    def get(cls, __name: str, device: torch.device) -> Any:
        global has_moved
        idx = device.index if device.index is not None else 0
        if not has_moved[idx]:
            cls.workspace[idx] = cls.workspace[idx].cuda(idx)
            cls.bias_g[idx] = cls.bias_g[idx].cuda(idx)
            has_moved[idx] = True
        if "bias" in __name:
            return cls.bias_g[idx]
        # Check the more specific name first so "workspace_size" is not
        # shadowed by the "workspace" match below.
        if "workspace_size" in __name:
            return cls.workspace_size
        if "workspace" in __name:
            return cls.workspace[idx]


@torch.no_grad()
def hgemv_simt(vec: torch.HalfTensor, mat: torch.HalfTensor, block_dim_x: int = 32):
    """Half-precision matrix-vector product via the custom SIMT HGEMV kernel."""
    prev_dims = vec.shape[:-1]
    out = _simt_hgemv(mat, vec.view(-1, 1), block_dim_x=block_dim_x).view(
        *prev_dims, -1
    )
    return out


@torch.no_grad()
def cublas_half_matmul_batched_simple(a: torch.Tensor, b: torch.Tensor):
    """Batched half-precision matmul via cuBLAS."""
    out = _cublas_hgemm_batched_simple(a, b)
    return out


@torch.no_grad()
def cublas_half_matmul_simple(a: torch.Tensor, b: torch.Tensor):
    """Half-precision `a @ b.T` via cuBLAS; used below as a bias-free F.linear."""
    out = _cublas_hgemm_axbT(b, a)
    return out


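# Illustrative check (a sketch, not executed at import time): on a CUDA device
# the plain wrapper above is expected to match a bias-free linear, e.g.
#
#   x = torch.randn(256, 1024, dtype=torch.float16, device="cuda")
#   w = torch.randn(4096, 1024, dtype=torch.float16, device="cuda")
#   torch.testing.assert_close(
#       cublas_half_matmul_simple(x, w), x @ w.T, rtol=1e-2, atol=1e-2
#   )
#
# with tolerances loose enough for half-precision accumulation differences.

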
@torch.no_grad()
def cublaslt_fused_half_matmul_simple(
    a: torch.Tensor,
    b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    epilogue_str: Optional[Literal["NONE", "RELU", "GELU"]] = "NONE",
):
    """Half-precision `a @ b.T` via cuBLASLt, with the bias add and an optional
    RELU/GELU epilogue fused into the kernel."""
    if bias is None:
        bias = StaticState.get("bias", a.device)
    out = _cublaslt_hgemm_simple(
        a, b, bias, epilogue_str, StaticState.get("workspace", a.device)
    )
    return out


@torch.no_grad()
def cublaslt_fused_half_matmul_batched_simple(
    a: torch.Tensor,
    b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    epilogue_str: Optional[Literal["NONE", "RELU", "GELU"]] = "NONE",
):
    """Batched variant of `cublaslt_fused_half_matmul_simple`."""
    if bias is None:
        bias = StaticState.get("bias", a.device)
    out = _cublaslt_hgemm_batched_simple(
        a, b, bias, epilogue_str, StaticState.get("workspace", a.device)
    )
    return out


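# Illustrative equivalence (a sketch, not executed at import time): on a CUDA
# device the fused call is expected to match F.linear followed by a separate
# activation, e.g.
#
#   x = torch.randn(256, 1024, dtype=torch.float16, device="cuda")
#   w = torch.randn(4096, 1024, dtype=torch.float16, device="cuda")
#   b = torch.randn(4096, dtype=torch.float16, device="cuda")
#   y_fused = cublaslt_fused_half_matmul_simple(x, w, bias=b, epilogue_str="GELU")
#   y_ref = F.gelu(F.linear(x, w, b))
#
# up to half-precision rounding and GELU-approximation differences.

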
class CublasLinear(nn.Linear):
    """Drop-in replacement for `nn.Linear` that runs half-precision matmuls
    through cuBLAS/cuBLASLt, optionally fusing the bias add and a RELU/GELU
    epilogue."""

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=torch.float16,
        epilogue_str="NONE",
    ):
        super().__init__(
            in_features, out_features, bias=bias, device=device, dtype=dtype
        )
        self._epilogue_str = epilogue_str
        self.has_bias = bias
        self.has_checked_weight = False

    def forward(self, x: Tensor) -> Tensor:
        if not self.has_checked_weight:
            if not self.weight.dtype == torch.float16:
                self.to(dtype=torch.float16)
            self.has_checked_weight = True
        out_dtype = x.dtype
        needs_convert = out_dtype != torch.float16
        if needs_convert:
            x = x.type(torch.float16)

        use_cublasLt = self.has_bias or self._epilogue_str != "NONE"
        if x.ndim == 1:
            x = x.unsqueeze(0)
        if math.prod(x.shape) == x.shape[-1]:
            # Single-row input: fall back to F.linear and apply the epilogue
            # (and any dtype conversion) afterwards.
            out = F.linear(x, self.weight, bias=self.bias)
            if self._epilogue_str == "RELU":
                out = F.relu(out)
            elif self._epilogue_str == "GELU":
                out = F.gelu(out)
            if needs_convert:
                return out.type(out_dtype)
            return out
        if use_cublasLt:
            leading_dims = x.shape[:-1]
            x = x.reshape(-1, x.shape[-1])
            out = cublaslt_fused_half_matmul_simple(
                x,
                self.weight,
                # self.bias is None when the layer was built with bias=False but
                # a fused epilogue was requested; fall back to the empty bias.
                bias=self.bias.data if self.bias is not None else None,
                epilogue_str=self._epilogue_str,
            )
            if needs_convert:
                return out.view(*leading_dims, out.shape[-1]).type(out_dtype)
            return out.view(*leading_dims, out.shape[-1])
        else:
            leading_dims = x.shape[:-1]
            x = x.reshape(-1, x.shape[-1])
            out = cublas_half_matmul_simple(x, self.weight)
            if needs_convert:
                return out.view(*leading_dims, out.shape[-1]).type(out_dtype)
            return out.view(*leading_dims, out.shape[-1])
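

# --- Usage sketch (illustrative only, not part of the library API) ---
# A minimal example of CublasLinear as a drop-in replacement for nn.Linear,
# assuming a CUDA device and the compiled cublas_ops_ext extension are
# available; the shapes below are arbitrary.
if __name__ == "__main__":
    if torch.cuda.is_available():
        linear = CublasLinear(1024, 4096, bias=True, epilogue_str="GELU").cuda()
        x = torch.randn(8, 512, 1024, dtype=torch.float16, device="cuda")
        out = linear(x)  # (8, 512, 4096), via cuBLASLt with a fused GELU epilogue
        ref = F.gelu(F.linear(x, linear.weight, linear.bias))
        print("max abs diff vs. F.linear + GELU:", (out - ref).abs().max().item())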