File size: 3,525 Bytes
29f5d34
 
 
3407de9
29f5d34
 
3407de9
 
29f5d34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3407de9
29f5d34
 
 
 
 
 
3407de9
29f5d34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
@echo off
setlocal enabledelayedexpansion

rem Default to a single compile job; the notes further down record that the
rem CUDA build goes OOM on the github runner with more parallel jobs.
set MAX_JOBS=1

:parseArgs
rem Consume "NAME value" argument pairs from left to right:
rem   WORKERS N              sets MAX_JOBS, the number of parallel compile jobs
rem   FORCE_CXX11_ABI VALUE  sets FLASH_ATTENTION_FORCE_CXX11_ABI
rem Parsing stops at the first argument that is not one of these names.
rem
rem Values are assigned with the quoted SET form so that no trailing space
rem ends up in the value. The previous one-line form, which chained SET,
rem SHIFT and GOTO with ampersands, stored the space in front of the first
rem ampersand as part of the value (e.g. MAX_JOBS became "4 " instead of "4"),
rem which is most likely why a MAX_JOBS passed this way was ignored.
if "%~1"=="WORKERS" (
    set "MAX_JOBS=%~2"
    shift
    shift
    goto :parseArgs
)
if "%~1"=="FORCE_CXX11_ABI" (
    set "FLASH_ATTENTION_FORCE_CXX11_ABI=%~2"
    shift
    shift
    goto :parseArgs
)
goto :buildContinue
:end

:buildFinalize
rem Cleanup and exit point. Blank the build-related variables listed below
rem (in this order), close the setlocal scope opened at the start of the
rem script, and leave the script. ENDLOCAL restores the environment saved by
rem SETLOCAL anyway, so the explicit blanking is belt-and-braces.
for %%v in (MAX_JOBS BUILD_TARGET DISTUTILS_USE_SDK FLASH_ATTENTION_FORCE_BUILD FLASH_ATTENTION_FORCE_CXX11_ABI dist_dir tmpname) do set %%v=
endlocal
goto :eof
:end

:buildContinue
echo MAX_JOBS: %MAX_JOBS%
echo FLASH_ATTENTION_FORCE_CXX11_ABI: %FLASH_ATTENTION_FORCE_CXX11_ABI%

rem # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
rem # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
rem # However this still fails so I'm using a newer version of setuptools
rem pip install setuptools==68.0.0
pip install "setuptools>=49.6.0" packaging wheel psutil
if errorlevel 1 (
    >&2 echo ERROR: pip install of the build requirements failed
    exit /b 1
)

rem # Limit MAX_JOBS otherwise the github runner goes OOM
rem # CUDA 11.8 can compile with 2 jobs, but CUDA 12.3 goes OOM
set FLASH_ATTENTION_FORCE_BUILD=TRUE
set BUILD_TARGET=cuda
set DISTUTILS_USE_SDK=1
set dist_dir=dist

rem Stop here if the build fails instead of trying to rename missing wheels.
python setup.py bdist_wheel --dist-dir="%dist_dir%"
if errorlevel 1 (
    >&2 echo ERROR: python setup.py bdist_wheel failed
    exit /b 1
)

rem rename whl: append a local version tag built from the CUDA version, the
rem torch version and torch's C++11 ABI flag.

rem just major version, such as cu12torch2.4cxx11abiFALSE
rem for /f "delims=" %%i in ('python -c "import sys; from packaging.version import parse; import torch; python_version = f'cp{sys.version_info.major}{sys.version_info.minor}'; cxx11_abi=str(torch._C._GLIBCXX_USE_CXX11_ABI).upper(); torch_cuda_version = parse(torch.version.cuda); torch_cuda_version = parse(\"11.8\") if torch_cuda_version.major == 11 else parse(\"12.4\"); cuda_version = f'{torch_cuda_version.major}'; torch_version_raw = parse(torch.__version__); torch_version = f'{torch_version_raw.major}.{torch_version_raw.minor}'; wheel_filename = f'cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}'; print(wheel_filename);"') do set wheel_filename=%%i

rem full versions, such as cu124torch2.4.0cxx11abiFALSE
rem Clear first so a stale value from the environment is never reused when
rem the python one-liner fails and prints nothing.
set "wheel_filename="
for /f "delims=" %%i in ('python -c "import sys; from packaging.version import parse; import torch; python_version = f'cp{sys.version_info.major}{sys.version_info.minor}'; cxx11_abi=str(torch._C._GLIBCXX_USE_CXX11_ABI).upper(); torch_cuda_version = parse(torch.version.cuda); cuda_version = \"\".join(map(str, torch_cuda_version.release)); torch_version_raw = parse(torch.__version__); torch_version = \".\".join(map(str, torch_version_raw.release)); wheel_filename = f'cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}'; print(wheel_filename);"') do set wheel_filename=%%i
if not defined wheel_filename (
    >&2 echo ERROR: could not compute the CUDA/torch version tag for the wheel name
    exit /b 1
)
set tmpname=%wheel_filename%

rem Insert "+TAG" right after the version field, i.e. before the second "-":
rem   pkg-1.0-cp310-cp310-win_amd64.whl  becomes
rem   pkg-1.0+TAG-cp310-cp310-win_amd64.whl
rem Wheels whose name already contains "+" are skipped; this also keeps a
rem freshly renamed file from being tagged twice if the directory scan
rem happens to see it again. There is deliberately no GOTO inside the loop:
rem a GOTO there aborts the whole FOR, so only the first wheel got renamed.
for %%i in ("%dist_dir%\*.whl") do (
    set "filename=%%~nxi"
    if "!filename:+=!"=="!filename!" (
        for /f "tokens=1,2,* delims=-" %%a in ("%%~nxi") do (
            if not "%%c"=="" (
                set "new_filename=%%a-%%b+%tmpname%-%%c"
                echo Renaming !filename! to !new_filename!
                move "%%~fi" "%dist_dir%\!new_filename!"
            )
        )
    )
)

goto :buildFinalize
:end