Skip to content

Commit

Permalink
[CI] Don't include <ATen/cuda/CUDAGraphsUtils.cuh>
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Dec 7, 2024
1 parent e782d28 commit 9375ac9
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion csrc/flash_attn/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
#include "philox_unpack.cuh" // For at::cuda::philox::unpack

#include <cutlass/numeric_types.h>

Expand Down
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/flash_fwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#pragma once

#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
#include "philox_unpack.cuh" // For at::cuda::philox::unpack

#include <cute/tensor.hpp>

Expand Down
4 changes: 4 additions & 0 deletions csrc/flash_attn/src/philox_unpack.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// This is purely so that it works with torch 2.1. For torch 2.2+ we can include ATen/cuda/PhiloxUtils.cuh

#pragma once
#include <ATen/cuda/detail/UnpackRaw.cuh>
2 changes: 1 addition & 1 deletion flash_attn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.7.1.post3"
__version__ = "2.7.1.post4"

from flash_attn.flash_attn_interface import (
flash_attn_func,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def check_if_rocm_home_none(global_option: str) -> None:


def append_nvcc_threads(nvcc_extra_args):
nvcc_threads = os.getenv("NVCC_THREADS") or "4"
nvcc_threads = os.getenv("NVCC_THREADS") or "2"
return nvcc_extra_args + ["--threads", nvcc_threads]


Expand Down

0 comments on commit 9375ac9

Please sign in to comment.