Skip to content

Build flash-attention Wheels for Windows #3

Build flash-attention Wheels for Windows

Build flash-attention Wheels for Windows #3

Workflow file for this run

# Builds flash-attention wheels on Windows runners for a given upstream
# version tag.
name: Build flash-attention Wheels for Windows

# Two triggers with an identical `version` input: a manual run
# (workflow_dispatch) and invocation from another workflow (workflow_call).
on:
  workflow_dispatch:
    inputs:
      version:
        description: 'Version tag of flash-attention to build: v2.3.4'
        default: 'v2.3.4'
        required: true
        type: string
  workflow_call:
    inputs:
      version:
        description: 'Version tag of flash-attention to build: v2.3.4'
        default: 'v2.3.4'
        required: true
        type: string

# Grants the workflow token write access to repository contents.
# NOTE(review): presumably required by the release-upload step - confirm.
permissions:
  contents: write
jobs:
  # One job per Python/CUDA combination in the matrix. Each job checks out
  # Dao-AILab/flash-attention at the requested tag, builds one wheel, uploads
  # it as a workflow artifact and attaches it to a GitHub release.
  build_wheels:
    name: Build wheels for Python ${{ matrix.pyver }} and CUDA ${{ matrix.cuda }}
    runs-on: windows-latest
    strategy:
      matrix:
        # Versions stay quoted so that "3.10" is not parsed as the float 3.1.
        pyver: ["3.8", "3.9", "3.10", "3.11"]
        cuda: ["12.1.1"]
    defaults:
      run:
        shell: pwsh
    env:
      # Read by the scripts below as $env:CUDAVER and $env:PCKGVER.
      CUDAVER: ${{ matrix.cuda }}
      PCKGVER: ${{ inputs.version }}
    steps:
      - uses: actions/checkout@v4
        with:
          repository: 'Dao-AILab/flash-attention'
          ref: ${{ inputs.version }}
          submodules: 'recursive'

      - uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.pyver }}

      # NOTE(review): the action reference was garbled to "[email protected]"
      # in the copy this file was recovered from. The inputs below belong to
      # conda-incubator/setup-miniconda; the exact pinned version is unknown,
      # so the v2 major tag is used here - confirm against the repository.
      - name: Setup Mamba
        uses: conda-incubator/setup-miniconda@v2
        with:
          activate-environment: "build"
          python-version: ${{ matrix.pyver }}
          miniforge-variant: Mambaforge
          miniforge-version: latest
          use-mamba: true
          add-pip-as-python-dependency: true
          auto-activate-base: false

      # Installs the 'cuda' package with mamba, passing one
      # nvidia/label/cuda-<major>.<minor>.<n> channel for every patch level
      # from the requested one down to 0. If 'mamba list cuda' does not report
      # the package, the install is retried once after 10 seconds and the step
      # then throws. Finally installs the Python build tooling and torch 2.1.0,
      # adding the PyTorch wheel index for the CUDA major+minor (e.g. cu121).
      - name: Install Dependencies
        run: |
          $cudaVersion = $env:CUDAVER
          $cudaVersionPytorch = $env:CUDAVER.Remove($env:CUDAVER.LastIndexOf('.')).Replace('.','')
          $cudaChannels = ''
          $cudaNum = [int]$cudaVersion.substring($cudaVersion.LastIndexOf('.')+1)
          while ($cudaNum -ge 0) { $cudaChannels += '-c nvidia/label/cuda-' + $cudaVersion.Remove($cudaVersion.LastIndexOf('.')+1) + $cudaNum + ' '; $cudaNum-- }
          mamba install -y 'cuda' $cudaChannels.TrimEnd().Split()
          if (!(mamba list cuda)[-1].contains('cuda')) {sleep -s 10; mamba install -y 'cuda' $cudaChannels.TrimEnd().Split()}
          if (!(mamba list cuda)[-1].contains('cuda')) {throw 'CUDA Toolkit failed to install!'}
          python -m pip install --upgrade build setuptools wheel packaging ninja torch==2.1.0 --extra-index-url "https://download.pytorch.org/whl/cu$cudaVersionPytorch"

      # Points CUDA_PATH/CUDA_HOME at the active conda environment, limits the
      # build to one job, builds the wheel without build isolation (-n), then
      # renames the first wheel in .\dist so its version carries the local tag
      # "+cu<major><minor>torch2.1cxx11abiFALSE".
      - name: Build Wheel
        id: build-wheel
        run: |
          $cudaVersion = $env:CUDAVER.Remove($env:CUDAVER.LastIndexOf('.')).Replace('.','')
          $packageVersion = $env:PCKGVER.TrimStart('v')
          $env:CUDA_PATH = $env:CONDA_PREFIX
          $env:CUDA_HOME = $env:CONDA_PREFIX
          $env:MAX_JOBS = '1'
          $env:FLASH_ATTENTION_FORCE_BUILD = 'TRUE'
          $env:FLASH_ATTENTION_FORCE_SINGLE_THREAD = 'TRUE'
          python -m build -n --wheel
          $wheel = (gi '.\dist\*.whl')[0]
          $wheelname = $wheel.name.replace("flash_attn-$packageVersion-","flash_attn-$packageVersion+cu$cudaVersion"+"torch2.1cxx11abiFALSE-")
          Move-Item $wheel.fullname ".\dist\$wheelname"

      # NOTE(review): every matrix job uploads to the same artifact name. This
      # relies on upload-artifact v3 merging uploads that share a name; newer
      # major versions are understood to require a unique name per job -
      # confirm before upgrading.
      - uses: actions/upload-artifact@v3
        with:
          name: 'windows-wheels'
          path: ./dist/*.whl

      # Best-effort publish: continue-on-error keeps the job green even when
      # the release upload fails.
      # NOTE(review): the action reference was garbled to "[email protected]"
      # in the copy this file was recovered from. The inputs below belong to
      # svenstaro/upload-release-action; the exact pinned version is unknown,
      # so the v2 major tag is used here - confirm against the repository.
      - name: Upload files to a GitHub release
        uses: svenstaro/upload-release-action@v2
        continue-on-error: true
        with:
          file: ./dist/*.whl
          tag: ${{ inputs.version }}
          file_glob: true
          overwrite: true
          release_name: ${{ inputs.version }}