
Use order from A matrix when determining DPAS layout #2834

Draft · wants to merge 1 commit into base: main
Conversation

@alexbaden (Contributor)
One of the learnings from mapping oneDNN kernels to Triton layouts is that the warpsPerCTA attribute for the DPAS layout should be modified to match the A matrix configuration. If A is a row-major matrix, warpsPerCTA should bias towards narrow blocks, presumably because narrow blocks better match the shape of the input matrix. If A is column-major, warpsPerCTA should bias towards tall blocks, because the fast-changing dimension of A is now the columns. For the GEMM case, this change has the effect of interspersing DPAS instructions and shuffles, which I believe reduces memory pressure and results in better performance for AxBT. AxB is unchanged, presumably because its loads are already quite efficient and need minimal shuffles.
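A minimal sketch of the heuristic described above (the function name and the exact bias are illustrative, not the actual pass code): distribute the warps so that the larger factor lands on A's slow-varying dimension, keeping each warp's tile narrow along A's fast-changing dimension.

```python
def choose_warps_per_cta(num_warps: int, a_order: tuple[int, int]) -> list[int]:
    """Pick a [warpsM, warpsN] factorization of num_warps biased by A's order.

    `a_order` follows the Triton convention: order[0] is the fastest-varying
    dimension, so (1, 0) means row-major A and (0, 1) means column-major A.
    The exact bias here is a hypothetical stand-in for the PR's heuristic.
    """
    assert num_warps > 0 and num_warps & (num_warps - 1) == 0, "power of two"
    warps = [1, 1]
    slow_dim, fast_dim = a_order[1], a_order[0]
    remaining = num_warps
    # Alternate factors of two, giving the slow dimension of A priority, so
    # each warp's tile stays narrow along A's fast-changing dimension.
    while remaining > 1:
        warps[slow_dim] *= 2
        remaining //= 2
        if remaining > 1:
            warps[fast_dim] *= 2
            remaining //= 2
    return warps
```

For 8 warps this yields [4, 2] for a row-major A and [2, 4] for a column-major A; the point is only that the factorization flips when the order of A flips.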

To allow the A matrix layout to affect the choice of DPAS layout, I needed the A matrix layout to properly convey the A matrix order. This information was recently made available via the AxisInfoAnalysis run during the Coalesce pass. Following the upstream pipeline convention, I moved Coalesce to the top of the TTGIR optimization pipeline, which properly tags all blocked layouts with the correct order. We then use the order of the A matrix to determine the row and column dimensions when mapping warps to tiles.
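Schematically, once Coalesce runs early in the pipeline, the A operand's blocked layout in TTGIR carries an `order` attribute reflecting its memory order. A hypothetical fragment (the attribute values are made up for illustration):

```mlir
// Illustrative only — values do not come from this PR.
// order = [1, 0] marks A as row-major (dimension 1 fastest-varying);
// a column-major A would instead carry order = [0, 1].
#blocked_a = #triton_gpu.blocked<{sizePerThread = [1, 4],
                                  threadsPerWarp = [1, 16],
                                  warpsPerCTA = [8, 4],
                                  order = [1, 0]}>
```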

I plan on modifying the tile sizes for AxB.T (and the A.T matrices) but wanted to split this change out as it is relatively compact but does modify the pass pipeline.

GEMM Performance with this change:

Compute A x B
✅ Triton and Torch match
AxB             torch                   triton
gbps mean       213.63573793442782      182.23695062827258
gbps min        181.8148801516574       162.30303718530993
gbps max        238.444109847901        193.78924336325545
tflops mean     132.58363372415397      113.09735602627339
tflops min      112.83541653048312      100.72624853197415
tflops max      147.97985968742458      120.26677891149912
Compute A x B.T
✅ Triton and Torch match
AxBT            torch                   triton
gbps mean       206.72928493799927      132.31416040867134
gbps min        176.25819312358402      109.0614239825154
gbps max        227.53161576706216      138.1909227807394
tflops mean     128.2974471372795       82.1149698536239
tflops min      109.3869028839697       67.6841807018762
tflops max      141.20749972452825      85.762124198471
Compute A.T x B
✅ Triton and Torch match
ATxB            torch                   triton
gbps mean       209.2849587186014       42.61530011404569
gbps min        170.08950422969238      41.92474555878943
gbps max        231.30353066510696      43.27539742013329
tflops mean     129.8835137744532       26.447313525322894
tflops min      105.55857717042724      26.01875118315174
tflops max      143.54837297034516      26.856973914070593
Compute A.T x B.T
✅ Triton and Torch match
ATxBT           torch                   triton
gbps mean       154.956421026935        34.898164836367926
gbps min        137.61935080756382      34.43770830725105
gbps max        164.27558049997361      35.10572250470419
tflops mean     96.16689401914024       21.658012601479243
tflops min      85.40740316784566       21.372250488863674
tflops max      101.95042086786239      21.78682414837399

Performance from main:

Compute A x B
✅ Triton and Torch match
AxB             torch                   triton
gbps mean       214.87746885154405      180.0439531109416
gbps min        178.36601535763648      153.21912702731154
gbps max        235.07477739243978      190.0428815826552
tflops mean     133.35425945695823      111.73636847612374
tflops min      110.69503013710286      95.0887188339194
tflops max      145.88883154536867      117.9417640852357
Compute A x B.T
✅ Triton and Torch match
AxBT            torch                   triton
gbps mean       206.68481224002502      104.29780682676666
gbps min        173.36176681958614      89.18300767881824
gbps max        226.93472316979117      105.54846253882953
tflops mean     128.2698471113852       64.72785102461154
tflops min      107.58936316560981      55.34751506855143
tflops max      140.83706456113097      65.50401553924934
Compute A.T x B
✅ Triton and Torch match
ATxB            torch                   triton
gbps mean       215.89206903981136      42.88010635752647
gbps min        177.41492802824132      42.44726352124893
gbps max        231.5511693295856       43.23214291842088
tflops mean     133.98392648288896      26.611653884913395
tflops min      110.1047795763146       26.343028997429634
tflops max      143.70205902636098      26.830129908159375
Compute A.T x B.T
✅ Triton and Torch match
ATxBT           torch                   triton
gbps mean       153.0380667319401       34.706542923002324
gbps min        132.35544226153112      34.58087504086132
gbps max        161.9991050835991       34.84553202358425
tflops mean     94.97635171727674       21.53909088069962
tflops min      82.14058962170172       21.461100631419388
tflops max      100.53762642763967      21.625348358878945

@alexbaden (Contributor, Author) commented Dec 6, 2024

The latest changes resolve performance regressions with gemm-preop-exp. I am looking into issues with the following:

@Jianhui-Li
I don't understand why we look only at the ordering attribute (layout) of matrix A to determine the DPAS layout. I thought we should look at the layouts of both matrix A and matrix B before deciding.

alexbaden added a commit that referenced this pull request Dec 9, 2024
…2956)

Required for #2834 

There are two reasons to do this: first, it properly tags the layouts with their memory order very early in the TTGIR pipeline; second, it moves our TTGIR pipeline closer to upstream's. I am splitting this change out to isolate any regressions or undesired behavior it causes from those caused by changing the DPAS layouts in #2834.

cc #2354
@alexbaden force-pushed the alex/transpose_warps_per_cta branch from 976c1a1 to 0ecff19 on December 13, 2024
@alexbaden (Contributor, Author)

When we determined that oneDNN used a different layout for AxBT, I started doing some testing and quickly realized that the warpsPerCTA parameter makes very little difference for AxB. However, for AxBT, or for GEMM kernels that do computation on the output of AxB (like AxB+D), the warpsPerCTA parameter does make a difference. I suspect that a wider subgroup block mapping inside a workgroup is beneficial when kernel time is dominated by B loads or by later computation. Having more A matrix rows within a workgroup should improve coalesced memory accesses and cache reuse. If A is transposed, then transposing warpsPerCTA accomplishes the same thing (more columns promote memory coalescing and cache reuse).
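To make the effect of transposing warpsPerCTA concrete, here is a small sketch of which output-tile origins a fixed set of warps covers under a wide versus a tall warp grid. It assumes a row-major linearization of warp ids and made-up tile sizes; the real DPAS layout also involves parameters like threadsPerWarp and repCluster, which this ignores.

```python
def warp_tile_origin(warp_id: int, warps_per_cta: list[int],
                     tile_m: int, tile_n: int) -> tuple[int, int]:
    """Map a linear warp id to the (row, col) origin of its tile.

    Illustrative only: assumes warps are numbered row-major across the
    [warpsM, warpsN] grid and that each warp owns a tile_m x tile_n tile.
    """
    warps_m, warps_n = warps_per_cta
    wm, wn = divmod(warp_id, warps_n)
    return (wm * tile_m, wn * tile_n)

# With a wide grid [1, 4], all four warps share the same rows of A;
# with a tall grid [4, 1], they span four different block rows.
origins_wide = [warp_tile_origin(w, [1, 4], 32, 64) for w in range(4)]
origins_tall = [warp_tile_origin(w, [4, 1], 32, 64) for w in range(4)]
```

Under the wide grid every warp reads the same 32 rows of A (good reuse of A when B loads dominate); under the tall grid the warps touch 128 distinct rows.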

@alexbaden force-pushed the alex/transpose_warps_per_cta branch from 0ecff19 to 7eee7e3 on December 18, 2024