Use 2D block loads for post-DPAS chained ops #3000

Merged 4 commits into main from alex/block_adds · Dec 18, 2024

Conversation

@alexbaden (Contributor) commented on Dec 12, 2024:

This is a proof of concept that uses the 2D block load for tensor pointer loads where the layout is a DPAS layout but the result of the load is not directly used in the DPAS computation. I built the PoC around the gemm_postop_addmatrix_benchmark kernel, which computes AxB + D. The D matrix uses a block pointer load, and the TTGIR applies the DPAS MMA layout to the D matrix load:

    %16 = tt.make_tensor_ptr %arg3, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %12] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #mma>> loc(#loc21)
    %17 = tt.load %16 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x256xf32, #mma>> loc(#loc22)
    %18 = arith.addf %15#0, %17 : tensor<256x256xf32, #mma> loc(#loc23)
    %19 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %12] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #mma>> loc(#loc24)
    tt.store %19, %18 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #mma>> loc(#loc25)

Currently, however, this code is lowered to scalar loads. That requires extracting all the scalar values from the DPAS registers, loading each scalar from a tensor of pointers for the D matrix, and computing the add scalar by scalar. A representative GEMM kernel for AxB is approximately 1100 ASM instructions, but with the scalar D-matrix addition it grows to 7500 instructions and is quite slow. With this PR, the ASM is back down to about 1500 instructions and performance is much better. I believe we both increase effective memory bandwidth by issuing larger loads and decrease register pressure by consuming the 2D block load data directly from registers. I modified the GEMM-with-block-pointers tutorial to do the matrix addition and compared directly with PyTorch.
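For reference, here is a minimal sketch of the kind of fused kernel involved. It is not the actual tutorial or benchmark code; the kernel name, argument names, and tile sizes are illustrative assumptions. The point it shows is that the D tile is loaded through a block pointer and added to the tl.dot accumulator, so the load inherits the DPAS (MMA) layout as in the TTGIR above:

    import triton
    import triton.language as tl

    @triton.jit
    def matmul_add_kernel(a_ptr, b_ptr, d_ptr, c_ptr,
                          M, N, K,
                          stride_am, stride_ak, stride_bk, stride_bn,
                          stride_dm, stride_dn, stride_cm, stride_cn,
                          BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)

        # Block pointers for the A and B tiles that feed the DPAS (tl.dot) computation.
        a_bp = tl.make_block_ptr(a_ptr, (M, K), (stride_am, stride_ak),
                                 (pid_m * BLOCK_M, 0), (BLOCK_M, BLOCK_K), (1, 0))
        b_bp = tl.make_block_ptr(b_ptr, (K, N), (stride_bk, stride_bn),
                                 (0, pid_n * BLOCK_N), (BLOCK_K, BLOCK_N), (1, 0))

        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for _ in range(0, K, BLOCK_K):
            a = tl.load(a_bp, boundary_check=(0, 1))
            b = tl.load(b_bp, boundary_check=(0, 1))
            acc = tl.dot(a, b, acc)
            a_bp = tl.advance(a_bp, (0, BLOCK_K))
            b_bp = tl.advance(b_bp, (BLOCK_K, 0))

        # Post-DPAS chained op: D is loaded through a block pointer but is not a
        # tl.dot operand, so the load result carries the accumulator's DPAS layout.
        d_bp = tl.make_block_ptr(d_ptr, (M, N), (stride_dm, stride_dn),
                                 (pid_m * BLOCK_M, pid_n * BLOCK_N), (BLOCK_M, BLOCK_N), (1, 0))
        acc = acc + tl.load(d_bp, boundary_check=(0, 1))

        c_bp = tl.make_block_ptr(c_ptr, (M, N), (stride_cm, stride_cn),
                                 (pid_m * BLOCK_M, pid_n * BLOCK_N), (BLOCK_M, BLOCK_N), (1, 0))
        tl.store(c_bp, acc, boundary_check=(0, 1))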

main:

Compute A x B
✅ Triton and Torch match
AxB             torch                   triton
gbps mean       156.7301576517312       138.11238212309073
gbps min        141.30597513931122      125.77424121785864
gbps max        165.34311994793367      142.42265041843044
tflops mean     97.2676857184077        85.71338139033023
tflops min      87.69534457130585       78.05625636793165
tflops max      102.61294231920246      88.38836001725622

w/ this change:

Compute A x B
✅ Triton and Torch match
AxB             torch                   triton
gbps mean       157.91230857925294      157.63376432129215
gbps min        140.6169072268218       143.2243668186498
gbps max        167.52037117445636      163.5302792054239
tflops mean     98.00133574857878       97.82846949394128
tflops min      87.26770484864576       88.88591007411962
tflops max      103.96415762584442      101.48788236748732

Note, however, that our GEMM performance is not as good as PyTorch/oneDNN:

Compute A x B
✅ Triton and Torch match
AxB             torch                   triton
gbps mean       223.0701969650354       181.18122189537945
gbps min        187.16468723584453      162.91434700909443
gbps max        238.444109847901        192.58129503451758
tflops mean     138.43871617708862      112.44216437628396
tflops min      116.15553923000287      101.10563111352282
tflops max      147.97985968742458      119.51711885778543

But with the matrix addition we are now even with, and sometimes slightly ahead of, PyTorch. This demonstrates the benefit of operator fusion in Triton; if we can improve GEMM performance a bit, we should be able to pull ahead of PyTorch for this use case.

The code here contains a lot of duplication because I was not sure what parts of the existing block load code would be relevant, and what could be ignored. I plan to clean up the duplication (and I expect a few test failures that I will have to fix) before this is ready to be moved out of draft.

cc #2997

@alexbaden
Copy link
Contributor Author

This also resolves the performance regression in #2834 for the add-matrix case. We get equal or slightly higher performance when we use a transposed warpsPerCTA order:

Standard warpsPerCTA order:

Compute A x B
✅ Triton and Torch match
AxB             torch                   triton
gbps mean       205.80269703051565      183.18315754413806
gbps min        179.1045897088753       162.24215503026394
gbps max        226.5780973097816       192.58129503451758
tflops mean     127.7224010662109       113.68457777284686
tflops min      111.15339385568986      100.68846469756984
tflops max      140.6157403910402       119.51711885778543
Compute A x B + D
✅ Triton and Torch match
AxB             torch                   triton
gbps mean       157.65789768085992      158.24855403301189
gbps min        142.79881082578163      141.67625237645345
gbps max        168.10632973298533      163.16016456641276
tflops mean     97.84344680315183       98.2100117150328
tflops min      88.6218074458184        87.9251408687808
tflops max      104.32780705853148      101.25818697939796

Transposed warpsPerCTA order:

Compute A x B
✅ Triton and Torch match
AxB             torch                   triton
gbps mean       220.02125371014108      189.8327155522294
gbps min        196.16217040160757      163.2217377764552
gbps max        234.8195487001742       196.96612401967587
tflops mean     136.54652351465722      117.81133377301992
tflops min      121.73943181287643      101.29639968672126
tflops max      145.7304350721081       122.23837030069579
Compute A x B + D
✅ Triton and Torch match
AxB             torch                   triton
gbps mean       161.31998225073477      166.80650419582096
gbps min        142.56348466801188      149.7706365256271
gbps max        168.8940264630101       176.18639363624158
tflops mean     100.1161586816681       103.52112745243674
tflops min      88.47576260608734       92.94856472863161
tflops max      104.81665642310443      109.34234368697658

@alexbaden force-pushed the alex/block_adds branch 2 times, most recently from 5532f2a to a312d98, on December 13, 2024 01:14
@alexbaden marked this pull request as ready for review on December 13, 2024 02:03
@alexbaden (Contributor, Author) commented:
I have been trying to break it today and have not uncovered any issues. I cleaned up a good chunk of the duplication and have some ideas for how to further abstract away the differences, but I'm going to save that for another PR. This is now ready for review.

@alexbaden changed the title from "[PoC]: Use 2D block loads for post-DPAS chained ops" to "Use 2D block loads for post-DPAS chained ops" on Dec 13, 2024
@jopperm (Contributor) left a comment:
Neat!

@alexbaden (Contributor, Author) commented:
I modified the addmatrix benchmarks to include int8 to get some idea of correctness and performance there. The result is correct, and I see block loads in the IR. Performance is better:

block adds:
matmul-performance-postop-addmatrix:
   B     M     K     N           dtype  Triton-GB/s  Triton-GB/s-min  Triton-GB/s-max  Triton-TFlops  Triton-TFlops-min  Triton-TFlops-max  Triton-CV
0  1  1024  1024  1024  torch.bfloat16   209.401106       113.975650       210.557421      53.606683          29.177766          53.902700   0.002067
1  1  1024  1024  1024      torch.int8   236.698860       234.057137       239.400923      60.594908          59.918627          61.286636   0.006385
2  1  2048  2048  2048  torch.bfloat16   187.350255       160.332722       206.007068      95.923331          82.090354         105.475619   0.064107
3  1  2048  2048  2048      torch.int8   222.953037       219.367361       225.500214     114.151955         112.316089         115.456110   0.006912
4  1  4096  4096  4096  torch.bfloat16   164.567223       152.215714       168.344529     168.516837         155.868891         172.384797   0.029105
5  1  4096  4096  4096      torch.int8   164.680295       159.358059       169.227511     168.632622         163.182653         173.288971   0.013594
6  1  8192  8192  8192  torch.bfloat16   103.265063        95.490571       105.903398     211.486848         195.564689         216.890160   0.019848
7  1  8192  8192  8192      torch.int8    88.904751        87.769900        89.910054     182.076931         179.752756         184.135790   0.007747

main: 
matmul-performance-postop-addmatrix:
   B     M     K     N           dtype  Triton-GB/s  Triton-GB/s-min  Triton-GB/s-max  Triton-TFlops  Triton-TFlops-min  Triton-TFlops-max  Triton-CV
0  1  1024  1024  1024  torch.bfloat16   220.289058        76.875076       224.054711      56.393999          19.680019          57.358006   0.008403
1  1  1024  1024  1024      torch.int8   215.423907       213.995100       217.546892      55.148520          54.782745          55.692004   0.003764
2  1  2048  2048  2048  torch.bfloat16   175.145800       162.318271       182.361043      89.674650          83.106955          93.368854   0.021099
3  1  2048  2048  2048      torch.int8   186.579361       185.260784       187.748615      95.528633          94.853522          96.127291   0.001961
4  1  4096  4096  4096  torch.bfloat16   137.149996       121.170133       149.369803     140.441596         124.078216         152.954678   0.064511
5  1  4096  4096  4096      torch.int8   135.321954       123.090362       136.024125     138.569681         126.044530         139.288704   0.006765
6  1  8192  8192  8192  torch.bfloat16    88.742803        87.917075        89.290380     181.745260         180.054169         182.866699   0.002579
7  1  8192  8192  8192      torch.int8    71.957158        69.739437        73.489197     147.368260         142.826368         150.505874   0.010592

I will PR the benchmark changes separately - torch does not support int8 MMA on the GPU, so we have to fall back to the CPU for the reference computation, which makes the benchmark take about 20x longer. We will have to decide how to handle that.
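As a rough illustration (an assumption about the benchmark setup, not the actual code), the int8 reference path ends up looking something like the sketch below: the reference product is computed on the CPU in a wider integer type and only moved back to the device for the comparison, which is what makes the run so much slower.

    import torch

    def int8_reference(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
        # torch has no int8 MMA kernel on the GPU, so compute the reference on the
        # CPU, widening to int32 to avoid overflow in the accumulation.
        ref = torch.matmul(a.cpu().to(torch.int32), b.cpu().to(torch.int32))
        ref = ref + d.cpu().to(torch.int32)
        return ref.to(a.device)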

? cast<DpasEncodingAttr>(encoding)
: cast<DpasEncodingAttr>(
getDotEncoding(tensorType).value().getParent());
auto dotOrder = dpasLayout.getThreadOrder();
On the lines above, a contributor commented:
We need to get the order from the layout encoding instead of from the parent layout for #ttg.dot_op.
Maybe we can use encoding.getThreadOrder.

@alexbaden (Contributor, Author) replied:
Thanks, now that you point it out I do not think it was a good idea to combine the dpasLayout and dotLayout into the same dpasLayout variable. I separated them again, since the dotLayout is only used in the code block after the conditional for the dpasLayout load.

@alexbaden enabled auto-merge (squash) on December 18, 2024 03:09
@alexbaden merged commit c4201fa into main on Dec 18, 2024
5 checks passed
@alexbaden deleted the alex/block_adds branch on December 18, 2024 04:24