diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml new file mode 100644 index 00000000000..7b3890b5554 --- /dev/null +++ b/.github/workflows/_build.yml @@ -0,0 +1,229 @@ +name: ~Build wheel template + +on: + workflow_call: + inputs: + runs-on: + description: "The runner to use for the build" + required: true + type: string + python-version: + description: "The Python version to use for the build" + required: true + type: string + cuda-version: + description: "The CUDA version to use for the build" + required: true + type: string + torch-version: + description: "The PyTorch version to use for the build" + required: true + type: string + cxx11_abi: + description: "The C++11 ABI to use for the build" + required: true + type: string + upload-to-release: + description: "Whether to upload the built wheel to the GitHub release" + required: false + type: boolean + default: false + release-version: + description: "Git ref to check out; also the tag of the release to upload to" + required: false + type: string + +defaults: + run: + shell: bash -x -e -u -o pipefail {0} + +jobs: + build-wheel: + runs-on: ${{ inputs.runs-on }} + name: Build wheel (${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}) + steps: + - name: Checkout + uses: actions/checkout@v5 + with: + ref: ${{ inputs.release-version }} + submodules: recursive + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python-version }} + + - name: Set CUDA and PyTorch versions + run: | + echo "MATRIX_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV + echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV + echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \.
{'print $1 $2'})" >> $GITHUB_ENV + + - name: Free up disk space + if: ${{ runner.os == 'Linux' }} + # https://github.com/easimon/maximize-build-space/blob/master/action.yml + # https://github.com/easimon/maximize-build-space/tree/test-report + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + + - name: Set up swap space + if: runner.os == 'Linux' + uses: pierotofy/set-swap-space@v1.0 + with: + swap-size-gb: 10 + + - name: Install CUDA ${{ inputs.cuda-version }} + if: ${{ inputs.cuda-version != 'cpu' }} + uses: Jimver/cuda-toolkit@v0.2.30 + id: cuda-toolkit + with: + cuda: ${{ inputs.cuda-version }} + linux-local-args: '["--toolkit"]' + # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 + # method: ${{ (inputs.cuda-version == '11.8.0' || inputs.cuda-version == '12.1.0') && 'network' || 'local' }} + method: "network" + # sub-packages: '["nvcc"]' + + - name: Install PyTorch ${{ inputs.torch-version }}+cu${{ inputs.cuda-version }} + run: | + pip install --upgrade pip + # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error + # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable + pip install typing-extensions==4.12.2 + # Pick the highest available PyTorch wheel CUDA version that doesn't exceed system CUDA + export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ + available = { \ + '2.6': [118, 124, 126], \ + '2.7': [118, 126, 128], \ + '2.8': [126, 128, 129], \ + '2.9': [126, 128, 130], \ + '2.10': [126, 128, 130], \ + '2.11': [126, 128, 130], \ + }[env['MATRIX_TORCH_VERSION']]; \ + sys_cuda = int(env['MATRIX_CUDA_VERSION']); \ + print(max(v for v in available if v <= sys_cuda))" \ + ) + # detect if we're on ARM + if [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then + PLAT=linux_aarch64 + else + PLAT=manylinux_2_27_x86_64.manylinux_2_28_x86_64 + fi + echo "PLAT=$PLAT" >> 
$GITHUB_ENV + if [[ ${{ inputs.torch-version }} == *"dev"* ]]; then + # pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} + # Can't use --no-deps because we need cudnn etc. + # Hard-coding this version of pytorch-triton for torch 2.9.0.dev20250904 + pip install jinja2 + TRITON_URL=https://download.pytorch.org/whl/nightly/pytorch_triton-3.4.0%2Bgitf7888497-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-${PLAT}.whl + TORCH_URL=https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_$(uname -m).whl + pip install --no-cache-dir --pre "${TRITON_URL}" + pip install --no-cache-dir --pre "${TORCH_URL}" + else + pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} + fi + nvcc --version + python --version + python -c "import torch; print('PyTorch:', torch.__version__)" + python -c "import torch; print('CUDA:', torch.version.cuda)" + python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" + + - name: Restore build cache + uses: actions/cache/restore@v4 + with: + path: build.tar + key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }} + restore-keys: | + build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}- + + - name: Unpack build cache + run: | + echo ::group::Adjust timestamps + sudo find / -exec touch -t 197001010000 {} + || true + echo ::endgroup:: + + if [ -f build.tar ]; then + find . -mindepth 1 -maxdepth 1 ! -name 'build.tar' -exec rm -rf {} + + tar -xpvf build.tar -C . 
+ else + echo "No build.tar found, skipping" + fi + + ls -al ./ + ls -al build/ || true + ls -al csrc/ || true + + - name: Build wheel + id: build_wheel + run: | + # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 + # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 + # However this still fails so I'm using a newer version of setuptools + pip install setuptools==75.8.0 + pip install ninja packaging wheel + export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH + # The shell runs with -u: append the old LD_LIBRARY_PATH only when it is set, so an unset variable does not abort the step + export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH} + # Limit MAX_JOBS otherwise the github runner goes OOM + # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM + export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] || [ "$MATRIX_CUDA_VERSION" == "130" ] && echo 1 || echo 2) + export NVCC_THREADS=2 + export FLASH_ATTENTION_FORCE_BUILD="TRUE" + export FLASH_ATTENTION_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }} + + # 5h timeout since GH allows max 6h and we want some buffer + EXIT_CODE=0 + timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$? + + if [ $EXIT_CODE -eq 0 ]; then + tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }} + wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") + ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} + echo "wheel_name=${wheel_name}" >> $GITHUB_ENV + fi + + # Publish the exit code as a step output (build_exit_code) so later steps can test for a timeout (124) + echo "build_exit_code=$EXIT_CODE" | tee -a "$GITHUB_OUTPUT" + + # Propagate the build's exit code; a non-zero code fails this step, and the timeout steps below still run via always() + exit $EXIT_CODE + + - name: Log build logs after timeout + if: always() && steps.build_wheel.outputs.build_exit_code == 124 + run: | + ls -al ./ + tar -cvf build.tar .
--atime-preserve=replace + + - name: Save build cache timeout + if: always() && steps.build_wheel.outputs.build_exit_code == 124 + uses: actions/cache/save@v4 + with: + key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }} + path: build.tar + + - name: Log Built Wheels + run: ls dist + + - name: Get Release with tag + id: get_current_release + if: inputs.upload-to-release + uses: joutvhu/get-release@v1 + with: + tag_name: ${{ inputs.release-version }} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Upload Release Asset + id: upload_release_asset + if: inputs.upload-to-release + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ steps.get_current_release.outputs.upload_url }} + asset_path: ./dist/${{env.wheel_name}} + asset_name: ${{env.wheel_name}} + asset_content_type: application/* diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 8d2ea71e4df..82d0d8d440b 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -13,200 +13,96 @@ on: - v* jobs: - setup_release: name: Create Release runs-on: ubuntu-latest + outputs: + release-version: ${{ steps.extract_branch.outputs.branch }} steps: - name: Get the tag version id: extract_branch - run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} + run: echo "branch=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT shell: bash - - name: Create Release - id: create_release - uses: actions/create-release@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - tag_name: ${{ steps.extract_branch.outputs.branch }} - release_name: ${{ steps.extract_branch.outputs.branch }} + # Derive the tag in the shell (same expansion as extract_branch) and quote it, so the tag text is never spliced into the script + run: gh release create "${GITHUB_REF#refs/tags/}" --repo "$GITHUB_REPOSITORY" --title "${GITHUB_REF#refs/tags/}" --generate-notes + shell: bash build_wheels: name: Build Wheel
needs: setup_release - runs-on: ${{ matrix.os }} - strategy: fail-fast: false matrix: - # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the - # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. - os: [ubuntu-22.04] - python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.4.0', '2.5.1', '2.6.0', '2.7.1', '2.8.0'] - cuda-version: ['12.9.1'] - # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. - # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. - # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) - # when building without C++11 ABI and using it on nvcr images. - cxx11_abi: ['FALSE', 'TRUE'] - exclude: - # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix - # Pytorch < 2.5 does not support Python 3.13 - - torch-version: '2.4.0' - python-version: '3.13' - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - name: Set CUDA and PyTorch versions - run: | - echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV - echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV - echo "WHEEL_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV - echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. 
{'print $1 $2'})" >> $GITHUB_ENV - - - name: Free up disk space - if: ${{ runner.os == 'Linux' }} - # https://github.com/easimon/maximize-build-space/blob/master/action.yml - # https://github.com/easimon/maximize-build-space/tree/test-report - run: | - sudo rm -rf /usr/share/dotnet - sudo rm -rf /opt/ghc - sudo rm -rf /opt/hostedtoolcache/CodeQL - - - name: Set up swap space - if: runner.os == 'Linux' - uses: pierotofy/set-swap-space@v1.0 - with: - swap-size-gb: 10 - - - name: Install CUDA ${{ matrix.cuda-version }} - if: ${{ matrix.cuda-version != 'cpu' }} - uses: Jimver/cuda-toolkit@v0.2.26 - id: cuda-toolkit - with: - cuda: ${{ matrix.cuda-version }} - linux-local-args: '["--toolkit"]' - # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 - # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }} - method: 'network' - sub-packages: '["nvcc"]' - - - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }} - run: | - pip install --upgrade pip - # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error - # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable - pip install typing-extensions==4.12.2 - # We want to figure out the CUDA version to download pytorch - # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 - # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix - # This code is ugly, maybe there's a better way to do this. 
- export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ - minv = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126}[env['MATRIX_TORCH_VERSION']]; \ - maxv = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129}[env['MATRIX_TORCH_VERSION']]; \ - print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ - ) - if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then - # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} - # Can't use --no-deps because we need cudnn etc. - # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001 - pip install jinja2 - pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl - pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl - else - pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} - fi - nvcc --version - python --version - python -c "import torch; print('PyTorch:', torch.__version__)" - python -c "import torch; print('CUDA:', torch.version.cuda)" - python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" - shell: - bash - - - name: Build wheel - run: | - # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 - # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 - # However this still fails so I'm using a newer version of setuptools - pip install setuptools==75.8.0 - pip install ninja packaging wheel - export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH 
- export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH - # Limit MAX_JOBS otherwise the github runner goes OOM - # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM - MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2) NVCC_THREADS=2 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist - tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }} - wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") - ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} - echo "wheel_name=${wheel_name}" >> $GITHUB_ENV - - - name: Log Built Wheels - run: | - ls dist - - - name: Get the tag version - id: extract_branch - run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} - - - name: Get Release with tag - id: get_current_release - uses: joutvhu/get-release@v1 - with: - tag_name: ${{ steps.extract_branch.outputs.branch }} - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: Upload Release Asset - id: upload_release_asset - uses: actions/upload-release-asset@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.get_current_release.outputs.upload_url }} - asset_path: ./dist/${{env.wheel_name}} - asset_name: ${{env.wheel_name}} - asset_content_type: application/* - - publish_package: - name: Publish package - needs: [build_wheels] - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v5 - with: - python-version: '3.10' - - - name: Install dependencies - run: | - pip install ninja packaging wheel twine - # Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv) - pip install setuptools==75.8.0 - # We don't want to download anything CUDA-related here - pip install torch --index-url https://download.pytorch.org/whl/cpu - - - name: Build core package - env: - 
FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE" - run: | - python setup.py sdist --dist-dir=dist - - - name: Deploy - env: - TWINE_USERNAME: "__token__" - TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} - run: | - python -m twine upload dist/* + # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the + # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. + os: [ubuntu-22.04] + python-version: ["3.13"] + torch-version: ["2.11.0"] + cuda-version: ["12.9.1", "13.0.1"] + # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. + # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. + # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) + # when building without C++11 ABI and using it on nvcr images. + cxx11_abi: ["TRUE"] + exclude: + # CUDA 13.0 is only supported by PyTorch 2.9+ + - torch-version: "2.6.0" + cuda-version: "13.0.1" + - torch-version: "2.7.1" + cuda-version: "13.0.1" + - torch-version: "2.8.0" + cuda-version: "13.0.1" + # No aarch64 PyTorch wheels for 2.6.0 + - torch-version: "2.6.0" + os: ubuntu-22.04-arm + # PyTorch 2.7+ pip wheels use CXX11_ABI=1 by default, no need for FALSE + - torch-version: "2.7.1" + cxx11_abi: "FALSE" + - torch-version: "2.8.0" + cxx11_abi: "FALSE" + - torch-version: "2.9.1" + cxx11_abi: "FALSE" + - torch-version: "2.10.0" + cxx11_abi: "FALSE" + - torch-version: "2.11.0" + cxx11_abi: "FALSE" + uses: ./.github/workflows/_build.yml + with: + runs-on: ${{ matrix.os }} + python-version: ${{ matrix.python-version }} + cuda-version: ${{ matrix.cuda-version }} + torch-version: ${{ matrix.torch-version }} + cxx11_abi: ${{ matrix.cxx11_abi }} + release-version: ${{ needs.setup_release.outputs.release-version }} + upload-to-release: true + + # publish_package: + # name: Publish package + # needs: [build_wheels] + # runs-on: ubuntu-latest + # steps: + # - uses: 
actions/checkout@v5 + # - uses: actions/setup-python@v6 + # with: + # python-version: "3.10" + # - name: Install dependencies + # run: | + # pip install ninja packaging wheel twine + # # Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv) + # pip install setuptools==75.8.0 + # # We don't want to download anything CUDA-related here + # pip install torch --index-url https://download.pytorch.org/whl/cpu + # - name: Build core package + # env: + # FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE" + # run: | + # python setup.py sdist --dist-dir=dist + # - name: Deploy + # env: + # TWINE_USERNAME: "__token__" + # TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} + # run: | + # python -m twine upload dist/* diff --git a/csrc/flash_attn/src/alibi.h b/csrc/flash_attn/src/alibi.h index a65a5b3790a..dda4cb35dc0 100644 --- a/csrc/flash_attn/src/alibi.h +++ b/csrc/flash_attn/src/alibi.h @@ -38,14 +38,20 @@ struct Alibi { const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q); #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + tensor(make_coord(i, mi), make_coord(j, nj)) 
+= ((col_idx == (col_idx_limit_right - 1)) ? 0 : alibi_slope); + } } } } @@ -62,7 +68,7 @@ struct Alibi { #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; - tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + tensor(make_coord(i, mi), make_coord(j, nj)) += (((row_idx + max_seqlen_k - max_seqlen_q - col_idx) == 0) ? 0 : alibi_slope); } } } diff --git a/csrc/flash_attn/src/mask.h b/csrc/flash_attn/src/mask.h index 544065d1de0..50cbdaf9e93 100644 --- a/csrc/flash_attn/src/mask.h +++ b/csrc/flash_attn/src/mask.h @@ -140,7 +140,7 @@ struct Mask { // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout())); // Do we need both row and column indices, or just column incides? - static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; + static constexpr bool Col_idx_only = !Has_alibi && !Is_local && !Causal_mask; /* any ALiBi needs the row path: both ALiBi branches below use row_idx or col_idx_limit_right */ const int lane_id = threadIdx.x % 32; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; if constexpr (Col_idx_only) { @@ -179,9 +179,10 @@ struct Mask { const int col_idx = col_idx_base + j; if constexpr (Has_alibi) { if constexpr (Is_causal) { - tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx; + tensor(make_coord(i, mi), make_coord(j, nj)) += ((col_idx == (col_idx_limit_right - 1)) ? 0 : alibi_slope); + } else { - tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + tensor(make_coord(i, mi), make_coord(j, nj)) += (((row_idx + max_seqlen_k - max_seqlen_q - col_idx) == 0) ? 0 : alibi_slope); } }