Rewrite benchmarks to be more elapsed_time friendly (#2186)
Performance with the current timing approach remains
[unchanged](https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/10794302056/job/29938275434),
but the rewrite greatly improves the reported numbers when the `elapsed_time`
method is used.

Part of #2149

Closes #2198

Signed-off-by: Anatoly Myachev <[email protected]>
anmyachev authored Sep 11, 2024
1 parent b2e18b2 commit 270c13a
Showing 4 changed files with 62 additions and 28 deletions.
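The same pattern runs through all four files: output and scratch buffers are allocated once, outside the callable that `benchmark_suit.do_bench` times, so that event-based timing charges only kernel execution to the benchmark. Below is a minimal, self-contained sketch of why that matters. It assumes a PyTorch build with XPU support where `torch.xpu.Event` and `torch.xpu.synchronize` mirror their CUDA counterparts (`record`/`synchronize`/`elapsed_time`); the `torch.matmul` pair merely stands in for the Triton/XeTLA kernels in this suite.

```python
import torch


def time_with_events(fn, n_iters=10):
    """Sketch: time fn via device events, assuming torch.xpu.Event supports
    record()/synchronize()/elapsed_time() like torch.cuda.Event does."""
    fn()  # warm-up (triggers any lazy compilation)
    torch.xpu.synchronize()
    times_ms = []
    for _ in range(n_iters):
        start = torch.xpu.Event(enable_timing=True)
        end = torch.xpu.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        end.synchronize()
        times_ms.append(start.elapsed_time(end))  # milliseconds
    return min(times_ms)


a = torch.randn(4096, 4096, device="xpu")
b = torch.randn(4096, 4096, device="xpu")

# Before this commit the benchmarked callables allocated their outputs
# internally, so the events also timed the allocation:
slow_fn = lambda: torch.matmul(a, b)

# After: allocate once outside the timed region and write into the buffer,
# so elapsed_time covers only the kernel itself:
c = torch.empty(4096, 4096, device="xpu")
fast_fn = lambda: torch.matmul(a, b, out=c)

print(time_with_events(slow_fn), time_with_events(fast_fn))
```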
@@ -235,7 +235,16 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
 
     elif provider == 'xetla':
         func = getattr(xetla_kernel, 'flash_attn')
-        xetla_fn = lambda: func(Z, H, D_HEAD, N_CTX, N_CTX)
+        out = torch.empty_like(q, device='xpu', dtype=dtype)
+        size_score = Z * H * N_CTX * N_CTX
+        size_attn_mask = Z * N_CTX * N_CTX
+        dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
+        bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
+        size_ml = Z * H * N_CTX
+        m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
+        l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
+
+        xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX)
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
                                                               fast_flush=False)
 
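For a sense of scale: with purely illustrative shapes (Z=4, H=48, N_CTX=1024 are assumed here, not necessarily the values this suite sweeps), the hoisted buffers are sizeable, and the dropout mask alone is about 192 MiB. Allocations of that size are exactly what used to be billed to the XeTLA kernel when the lambda created them on every call.

```python
# Buffer-size arithmetic for the allocations hoisted above, using assumed
# (illustrative) benchmark shapes and a 2-byte dtype such as fp16/bf16.
Z, H, N_CTX = 4, 48, 1024
MIB = 1024 * 1024

size_score = Z * H * N_CTX * N_CTX      # dropout_mask elements (uint8)
size_attn_mask = Z * N_CTX * N_CTX      # bias elements (2-byte dtype)
size_ml = Z * H * N_CTX                 # m and l elements (4-byte float)

print(size_score * 1 / MIB)             # 192.0 MiB  (dropout_mask)
print(size_attn_mask * 2 / MIB)         # 8.0 MiB    (bias)
print(2 * size_ml * 4 / MIB)            # 1.5 MiB    (m and l together)
```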
10 changes: 5 additions & 5 deletions benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -88,16 +88,14 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
 MAX_WORK_GROUP_SIZE = properties["max_work_group_size"]
 
 
-def softmax(x):
+def softmax(x, y):
     n_rows, n_cols = x.shape
 
     # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
     BLOCK_SIZE_X = triton.next_power_of_2(n_cols)
     BLOCK_SIZE_Y = MAX_WORK_GROUP_SIZE // BLOCK_SIZE_X
     BLOCK_SIZE_Y = BLOCK_SIZE_Y if BLOCK_SIZE_Y > 0 else 1
 
-    # Allocate output
-    y = torch.empty_like(x)
     # Create a number of persistent programs.
     softmax_kernel[(n_rows // BLOCK_SIZE_Y, )](y, x, x.stride(0), y.stride(0), n_cols, BLOCK_SIZE_X=BLOCK_SIZE_X,
                                                BLOCK_SIZE_Y=BLOCK_SIZE_Y)
@@ -133,7 +131,8 @@ def benchmark(M, N, provider):
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles,
                                                               warmup=10, rep=10)
     if provider == "triton":
-        triton_fn = lambda: softmax(x)
+        out = torch.empty_like(x, device="xpu")
+        triton_fn = lambda: softmax(x, out)
         torch_fn = lambda: torch.softmax(x, axis=-1)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10)
@@ -145,7 +144,8 @@ def benchmark(M, N, provider):
     elif provider == "xetla":
         name = f"softmax_shape_{M}_{N}"
         func = getattr(xetla_kernel, name)
-        xetla_fn = lambda: func(x, 0)
+        out = torch.empty_like(x, device="xpu")
+        xetla_fn = lambda: func(x, out, 0)
         torch_fn = lambda: torch.softmax(x, axis=-1)
         # benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch")
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10)
41 changes: 26 additions & 15 deletions benchmarks/xetla_kernel/flash_attention/fmha_forward_v5.h
@@ -620,7 +620,9 @@ class FmhaForwardKernel;
 // The launcher of fmha forward kernel
 template <typename fmha_policy, typename T, bool kUseBias = false,
           bool kIsCausal = false, bool kIsTraining = false>
-sycl::event fmha_forward_impl(sycl::queue &q, uint32_t num_batches,
+sycl::event fmha_forward_impl(sycl::queue &q, void *_q, void *_k, void *_v,
+                              void *_out, void *_dropout_mask, void *_bias,
+                              void *_m, void *_l, uint32_t num_batches,
                               uint32_t num_heads, uint32_t head_size,
                               uint32_t num_queries, uint32_t num_keys,
                               uint64_t seed = 0, uint64_t offset = 123) {
@@ -642,14 +644,23 @@ sycl::event fmha_forward_impl(sycl::queue &q, uint32_t num_batches,
   uint32_t size_ml = shape.get_ml_size();
 
   // forward
-  T *query = sycl::malloc_shared<T>(size_query, q);
-  T *key = sycl::malloc_shared<T>(size_key, q);
-  T *value = sycl::malloc_shared<T>(size_key, q);
-  T *bias = sycl::malloc_shared<T>(size_attn_mask, q);
-  uint8_t *dropout_mask = sycl::malloc_shared<uint8_t>(size_score, q);
-  T *out = sycl::malloc_shared<T>(size_query, q);
-  float *m = sycl::malloc_shared<float>(size_ml, q);
-  float *l = sycl::malloc_shared<float>(size_ml, q);
+  // T *query = sycl::malloc_shared<T>(size_query, q);
+  // T *key = sycl::malloc_shared<T>(size_key, q);
+  // T *value = sycl::malloc_shared<T>(size_key, q);
+  T *query = static_cast<T *>(_q);
+  T *key = static_cast<T *>(_k);
+  T *value = static_cast<T *>(_v);
+
+  // T *bias = sycl::malloc_shared<T>(size_attn_mask, q);
+  T *bias = static_cast<T *>(_bias);
+  // uint8_t *dropout_mask = sycl::malloc_shared<uint8_t>(size_score, q);
+  uint8_t *dropout_mask = static_cast<uint8_t *>(_dropout_mask);
+  // T *out = sycl::malloc_shared<T>(size_query, q);
+  T *out = static_cast<T *>(_out);
+  // float *m = sycl::malloc_shared<float>(size_ml, q);
+  float *m = static_cast<float *>(_m);
+  // float *l = sycl::malloc_shared<float>(size_ml, q);
+  float *l = static_cast<float *>(_l);
 
   // fmha forward kernel
   using fmha_forward_op_t =
@@ -676,12 +687,12 @@ sycl::event fmha_forward_impl(sycl::queue &q, uint32_t num_batches,
       fmha_fwd_op(ei, args);
     });
   });
-  sycl::free(query, q);
-  sycl::free(key, q);
-  sycl::free(value, q);
-  sycl::free(bias, q);
-  sycl::free(dropout_mask, q);
-  sycl::free(out, q);
+  // sycl::free(query, q);
+  // sycl::free(key, q);
+  // sycl::free(value, q);
+  // sycl::free(bias, q);
+  // sycl::free(dropout_mask, q);
+  // sycl::free(out, q);
   return event;
 }
 
28 changes: 21 additions & 7 deletions benchmarks/xetla_kernel/python_main.cpp
@@ -25,12 +25,12 @@ sycl::queue get_current_sycl_queue() {
   CHECK_CONTIGUOUS(x)
 
 template <typename T>
-at::Tensor softmax(const at::Tensor &input, const int64_t dim) {
+at::Tensor softmax(const at::Tensor &input, const at::Tensor &output,
+                   const int64_t dim) {
   CHECK_INPUT(input);
+  CHECK_INPUT(output);
   RECORD_FUNCTION("xetla softmax", {input});
 
-  auto output = at::empty_like(input);
-
   auto queue = get_current_sycl_queue();
   auto evt = softmax_forward<T>(input.data_ptr(), output.data_ptr(), queue);
   xpu::profiler_record("xetla kernel", evt);
@@ -72,13 +72,27 @@ at::Tensor bf16_stream_k_gemm(const at::Tensor &a, const at::Tensor &b,
 
 #define CALL_IMPL_ATTENTION_FUNC(P)                                            \
   fmha::fmha_forward_impl<P, T, use_mask, IsCausal, use_dropout>(              \
-      queue, num_batches, num_heads, head_size, num_queries, num_keys)
+      queue, q.data_ptr(), k.data_ptr(), v.data_ptr(), out.data_ptr(),         \
+      dropout_mask.data_ptr(), bias.data_ptr(), m.data_ptr(), l.data_ptr(),    \
+      num_batches, num_heads, head_size, num_queries, num_keys)
 
 template <bool use_mask = false, bool IsCausal = false,
           bool use_dropout = false>
-void flash_attn(const int64_t num_batches, const int64_t num_heads,
-                const int64_t head_size, const int64_t num_queries,
-                const int64_t num_keys) {
+void flash_attn(const at::Tensor &q, const at::Tensor &k, const at::Tensor &v,
+                const at::Tensor &out, const at::Tensor &dropout_mask,
+                const at::Tensor &bias, const at::Tensor &m,
+                const at::Tensor &l, const int64_t num_batches,
+                const int64_t num_heads, const int64_t head_size,
+                const int64_t num_queries, const int64_t num_keys) {
+
+  CHECK_INPUT(q);
+  CHECK_INPUT(k);
+  CHECK_INPUT(v);
+  CHECK_INPUT(out);
+  CHECK_INPUT(dropout_mask);
+  CHECK_INPUT(bias);
+  CHECK_INPUT(m);
+  CHECK_INPUT(l);
   RECORD_FUNCTION("xetla fa",
                   {num_batches, num_heads, head_size, num_queries, num_keys});
 