# Build flash-attention Wheels for Windows (run #5)
#
# Workflow file for this run:

name: Build flash-attention Wheels for Windows

# Two entry points, both taking the same single `version` input:
#   * workflow_dispatch - started manually
#   * workflow_call     - invoked from another workflow
on:
  workflow_dispatch:
    inputs:
      version:
        type: string
        required: true
        default: 'v2.3.4'
        description: 'Version tag of flash-attention to build: v2.3.4'
  workflow_call:
    inputs:
      version:
        type: string
        required: true
        default: 'v2.3.4'
        description: 'Version tag of flash-attention to build: v2.3.4'

# Token scope for every job in this workflow: write access to repository
# contents (presumably required by the release-upload step; confirm before
# narrowing).
permissions:
  contents: write
jobs:
  # Builds one wheel per (Python, CUDA, Torch) matrix combination on a Windows
  # runner, then publishes it as a workflow artifact and as a release asset.
  build_wheels:
    name: Build wheels for Python ${{ matrix.pyver }}, CUDA ${{ matrix.cuda }}, and Torch ${{ matrix.torchver }}
    runs-on: windows-latest
    strategy:
      matrix:
        # Kept quoted so that values such as 3.10 stay strings (not float 3.1).
        pyver: ["3.10", "3.11"]
        cuda: ["12.2.2"]
        torchver: ["2.1.2", "2.2.0"]
    defaults:
      run:
        shell: pwsh
    env:
      # Matrix/input values are handed to the scripts through environment
      # variables instead of being expanded inline into the script text.
      CUDAVER: ${{ matrix.cuda }}
      PCKGVER: ${{ inputs.version }}
      TORCHVER: ${{ matrix.torchver }}
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: 'recursive'

      - uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.pyver }}

      - name: Setup Mamba
        # NOTE(review): this reference was garbled in the captured source
        # ("conda-incubator/[email protected]"). The action name is inferred
        # from the inputs below; the pinned version is a best guess - confirm.
        uses: conda-incubator/setup-miniconda@v2.2.0
        with:
          activate-environment: "build"
          python-version: ${{ matrix.pyver }}
          miniforge-variant: Mambaforge
          miniforge-version: latest
          use-mamba: true
          add-pip-as-python-dependency: true
          auto-activate-base: false

      - name: Install Dependencies
        run: |
          # Build the channel arguments for the requested CUDA release and
          # every earlier patch release of the same minor version, e.g. for
          # 12.2.2: -c nvidia/label/cuda-12.2.2 -c nvidia/label/cuda-12.2.1
          #         -c nvidia/label/cuda-12.2.0
          $cudaVersion = $env:CUDAVER
          $patchStart = $cudaVersion.LastIndexOf('.') + 1
          $cudaPrefix = $cudaVersion.Remove($patchStart)
          $cudaPatch = [int]$cudaVersion.Substring($patchStart)
          $cudaChannels = @()
          for ($patch = $cudaPatch; $patch -ge 0; $patch--) {
            $cudaChannels += '-c', "nvidia/label/cuda-$cudaPrefix$patch"
          }

          # The toolkit counts as installed when the last line printed by
          # `mamba list cuda` mentions a cuda package.
          function Test-CudaInstalled { (mamba list cuda)[-1].Contains('cuda') }

          mamba install -y 'cuda' $cudaChannels
          if (-not (Test-CudaInstalled)) {
            # One retry after a short pause before giving up.
            Start-Sleep -Seconds 10
            mamba install -y 'cuda' $cudaChannels
          }
          if (-not (Test-CudaInstalled)) { throw 'CUDA Toolkit failed to install!' }

          # NOTE(review): the PyTorch index is pinned to cu121 regardless of
          # CUDAVER - presumably because no matching index exists for 12.2;
          # revisit when matrix.cuda changes.
          python -m pip install --upgrade build setuptools wheel packaging ninja psutil "torch==$env:TORCHVER" --extra-index-url "https://download.pytorch.org/whl/cu121"

      - name: Build Wheel
        id: build-wheel
        run: |
          # "12.2.2" -> "122"; "v2.3.4" -> "2.3.4"
          $cudaVersion = $env:CUDAVER.Remove($env:CUDAVER.LastIndexOf('.')).Replace('.','')
          $packageVersion = $env:PCKGVER.TrimStart('v')

          # Point the build at the CUDA toolkit installed into the conda env.
          $env:CUDA_PATH = $env:CONDA_PREFIX
          $env:CUDA_HOME = $env:CONDA_PREFIX
          $env:MAX_JOBS = '1'
          $env:FLASH_ATTENTION_FORCE_BUILD = 'TRUE'
          $env:FLASH_ATTENTION_FORCE_SINGLE_THREAD = 'TRUE'

          python -m build -n --wheel
          # Only the last command's exit code is checked automatically, so
          # fail here with a clear message instead of at the rename below.
          if ($LASTEXITCODE -ne 0) { throw "Wheel build failed with exit code $LASTEXITCODE" }

          $wheel = Get-Item '.\dist\*.whl' | Select-Object -First 1
          if (-not $wheel) { throw 'No wheel was produced in .\dist' }

          # Insert the local version tag "+cu<cuda>torch<torch>cxx11abiFALSE"
          # right after the package version in the wheel file name.
          $wheelname = $wheel.Name.Replace("flash_attn-$packageVersion-", "flash_attn-$packageVersion+cu${cudaVersion}torch${env:TORCHVER}cxx11abiFALSE-")
          Move-Item $wheel.FullName ".\dist\$wheelname"

      # upload-artifact v3 is deprecated and rejected on github.com. v4 does
      # not allow several jobs to upload to the same artifact name, so each
      # matrix job gets its own artifact; download them together with
      # `pattern: windows-wheels-*` and `merge-multiple: true`.
      - uses: actions/upload-artifact@v4
        with:
          name: windows-wheels-py${{ matrix.pyver }}-cu${{ matrix.cuda }}-torch${{ matrix.torchver }}
          path: ./dist/*.whl

      - name: Upload files to a GitHub release
        # NOTE(review): this reference was garbled in the captured source
        # ("svenstaro/[email protected]"). The action name is inferred from
        # the inputs below; the pinned version is a best guess - confirm.
        uses: svenstaro/upload-release-action@2.6.1
        # Best effort: a failed release upload does not fail the job.
        continue-on-error: true
        with:
          file: ./dist/*.whl
          tag: ${{ inputs.version }}
          file_glob: true
          overwrite: true
          release_name: ${{ inputs.version }}