[DOCUMENTS] Update the DPAS encoding documents. #2746

Open · wants to merge 2 commits into base: main
@@ -14,52 +14,168 @@ def DpasEncodingAttr : DistributedEncoding<"DpasEncoding", "intel_dpas_encoding"
let mnemonic = "dpas";

let description = [{
An encoding for tensors distributed across the threads for the C and D operands of the XMX tensor core operation,
and for the corresponding A and B operand layouts that use the DPAS encoding as their parent.
The XMX tensor core operation computes the matrix multiplication D = A*B + C.
The shape of the XMX tensor core operation is defined by the systolic depth, repeat count, execution size, and operations per channel.
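As a rough illustration, the per-instruction operand shapes follow directly from these parameters. The sketch below is a hedged model of the shape conventions described in this document; the struct and method names are illustrative, not the dialect's API:

```
#include <array>

// Minimal sketch of the per-instruction DPAS operand shapes, assuming the
// conventions in this description (names are illustrative only).
struct DpasInstShape {
  unsigned repeatCount;   // M, in the range [1, 8]
  unsigned systolicDepth; // 8 on PVC/ATSM
  unsigned executionSize; // 16 on PVC, 8 on ATSM
  unsigned opsPerChannel; // 4 / 2 / 1 for 8- / 16- / 32-bit scalars

  // A is M x K, where K = systolicDepth * opsPerChannel.
  std::array<unsigned, 2> shapeA() const {
    return {repeatCount, systolicDepth * opsPerChannel};
  }
  // B is K x N, where N = executionSize.
  std::array<unsigned, 2> shapeB() const {
    return {systolicDepth * opsPerChannel, executionSize};
  }
  // C and D are M x N.
  std::array<unsigned, 2> shapeC() const {
    return {repeatCount, executionSize};
  }
};
```

For example, repeatCount=8, systolicDepth=8, executionSize=16, and opsPerChannel=2 give A = 8x16, B = 16x16, and C/D = 8x16, matching the diagrams below.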

The encoding is characterized by parameters:
- `repeatCount` which shall be in the range [1, 8]
- `systolicDepth` For PVC/ATSM, the size is 8.
- `executionSize` For PVC, the size is 16. For ATSM, the size is 8.
- `opsPerChannel` 4 for 8-bit scalar types, 2 for 16-bit scalar types, 1 for 32-bit scalar types.
- `warpsPerCTA`
- `subGroupSize` valid sub-group sizes are 8, 16, and 32


An example layout with repeat_count=8, systolic_depth=8, execution_size=16, and operands_per_chan=2 for warp size 32.
For A operand:
systolic depth = 8
<------------------------------------------------------------------------------------------------->
opsPerChan=2
<--------->
t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 ^
t8 ... t8 t9 ... t9 t10 ... t10 t11 ... t11 t12 ... t12 t13 ... t13 t14 ... t14 t15 ... t15 |
t16 ... t16 t17 ... t17 t18 ... t18 t19 ... t19 t20 ... t20 t21 ... t21 t22 ... t22 t23 ... t23 |
t24 ... t24 t25 ... t25 t26 ... t26 t27 ... t27 t28 ... t28 t29 ... t29 t30 ... t30 t31 ... t31 | repeat count <= 8
t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 |
t8 ... t8 t9 ... t9 t10 ... t10 t11 ... t11 t12 ... t12 t13 ... t13 t14 ... t14 t15 ... t15 |
t16 ... t16 t17 ... t17 t18 ... t18 t19 ... t19 t20 ... t20 t21 ... t21 t22 ... t22 t23 ... t23 |
t24 ... t24 t25 ... t25 t26 ... t26 t27 ... t27 t28 ... t28 t29 ... t29 t30 ... t30 t31 ... t31 v

For B operand:
execution size = 16
<------------------------------------------------------------->
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^ ^
. . . . . . . . . . . . . . . . | opsPerChan=2|
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v |
t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
. . . . . . . . . . . . . . . . |
t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 | systolic depth = 8
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
. . . . . . . . . . . . . . . . |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
. . . . . . . . . . . . . . . . |
t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 v

This pattern repeats every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
along the row (resp. col) dimension.
- `opsPerChannel` 4 for 8-bit scalar types of the A/B operands of the DPAS instruction,
    2 for 16-bit scalar types of the A/B operands,
    1 for 32-bit scalar types of the A/B operands.
- `warpsPerCTA` indicates the distribution of the warps in the block. The order is [1, 0] for rank 2.
- `repCluster` indicates the cluster size of the repetitions of the DPAS tile.
- `threadsPerWarp_` AKA threadsPerWarp; the name threadsPerWarp_ avoids a conflict
    with `getThreadsPerWarp` in the DistributedLayout interface. Currently only 16 is supported.
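The dialect rejects configurations where the sub-group size is smaller than the execution size (see the fatal errors in Dialect.cpp further down). A minimal sketch of that constraint, with a hypothetical function name:

```
// Hypothetical checker mirroring the checks in Dialect.cpp: the sub-group
// size (threadsPerWarp_) must be at least the execution size.
inline bool isValidDpasConfig(unsigned threadsPerWarp, unsigned executionSize) {
  return threadsPerWarp >= executionSize; // e.g. 16 >= 16 on PVC
}
```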

The values of the matrix are distributed across the threads in the subgroup in row-major order.
- If the column size of the matrix is equal to the number of threads in the subgroup, one scalar represents one row of the matrix in a register.
- If the column size of the matrix is less than the number of threads in the subgroup, one scalar represents multiple rows of the matrix in a register.
- If the column size of the matrix is larger than the number of threads in the subgroup, one scalar represents a partial row of the matrix in a register.
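These three cases can be summarized by a single row-major mapping from a matrix element to the lane that holds it. The sketch below is an illustrative model of the diagrams in the examples that follow, not code from the dialect; `packed` is the number of consecutive elements a lane owns within one row:

```
#include <algorithm>

// Illustrative model of the row-major distribution pictured below.
// cols == threads: lane c owns one element per row            (Example 1)
// cols <  threads: one register scalar spans several rows     (Example 2)
// cols >  threads: a lane owns `packed` consecutive elements  (Example 3)
inline unsigned ownerLane(unsigned row, unsigned col, unsigned cols,
                          unsigned threadsPerWarp) {
  unsigned packed = std::max(1u, cols / threadsPerWarp);
  unsigned linear = row * cols + col; // row-major element index
  return (linear / packed) % threadsPerWarp;
}
```

For instance, with cols=8 and threadsPerWarp=16 this reproduces the alternating t0..t7 / t8..t15 rows of Example 2 below.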

Example 1: the column size of the matrix is 16 and the number of threads in the subgroup is 16.
The DPAS encoding has repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2, and threadsPerWarp=16.

The layout for A operand:
K = 16 (K = systolic depth * opsPerChan)
<---------------------------------------------------------------------------->

t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | M = 8 (repeat count)
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v

The layout for B operand:
N = 16 (N = execution size)
<---------------------------------------------------------------------------->

t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | K = 16 (K = systolic depth * opsPerChan)
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v

The layout for C operand and result D:
N = 16 (N = execution size)
<---------------------------------------------------------------------------->
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | M = 8 (M=repeat count)
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v
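As a quick check of this example, the number of scalars each lane holds is simply the tile size divided by the thread count (the same computation getSizePerThread performs in Dialect.cpp below); the helper name is an assumption:

```
// Scalars per lane = rows * cols / threadsPerWarp (Example 1 tiles).
constexpr unsigned elemsPerLane(unsigned rows, unsigned cols,
                                unsigned threads) {
  return rows * cols / threads;
}
static_assert(elemsPerLane(8, 16, 16) == 8, "A and C/D: 8 scalars per lane");
static_assert(elemsPerLane(16, 16, 16) == 16, "B: 16 scalars per lane");
```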

Example 2: the column size of the matrix is 8 and the number of threads in the subgroup is 16.
[Review comment, Contributor]
Same as my previous comment. I think this fits but please confirm:

A: tensor<8x8xf32>
B: tensor<8x16xf32>
D: tensor<8x8xf32>
dpasEncoding:  triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChannel = 1, threadsPerWarp = 16,  warpsPerCTA = [1,1] , repCluster = [1,1]}>

[Author reply]
I updated the example of the tt.dot operation. Please check.
The DPAS encoding has repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1, and threadsPerWarp=16.

The layout for A operand:
K = 8 (K = systolic depth * opsPerChan)
<---------------------------------------->

t0 t1 t2 t3 t4 t5 t6 t7 ^
t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 |
t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 | M = 8 (repeat count)
t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 |
t8 t9 t10 t11 t12 t13 t14 t15 v

The layout for the B operand is like the opsPerChan=2 case, but with a K size of 8.
The layouts for the C and D operands are the same as in the opsPerChan=2 case.

Example 3: the column size of the matrix is 32 and the number of threads in the subgroup is 16.
The DPAS encoding has repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4, and threadsPerWarp=16.

The layout for A operand:
K = 32 (K = systolic depth * opsPerChan)
<----------------------------------------------------------------------------------------------------------------------------------->

t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 ^
t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 | M = 8 (repeat count)
t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 v

The layout for the B operand is like the opsPerChan=2 case, but with a K size of 32.
The layouts for the C and D operands are the same as in the opsPerChan=2 case.

The patterns illustrated above repeat every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
along the row (resp. column) dimension, and the repetitions are clustered with size repCluster to optimize memory access.
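A hedged sketch of how repCluster scales the single-instruction tile into the per-warp tile, assuming (as the diagrams below suggest) that A is clustered along M, B along N, and C/D along both; the function names are illustrative, not the dialect's accessors:

```
#include <array>

using Shape2D = std::array<unsigned, 2>; // {rows, cols}

// repCluster repeats the DPAS instruction tile within one warp.
Shape2D warpShapeA(Shape2D instA, Shape2D repCluster) {
  return {instA[0] * repCluster[0], instA[1]};                 // M scaled
}
Shape2D warpShapeB(Shape2D instB, Shape2D repCluster) {
  return {instB[0], instB[1] * repCluster[1]};                 // N scaled
}
Shape2D warpShapeC(Shape2D instC, Shape2D repCluster) {
  return {instC[0] * repCluster[0], instC[1] * repCluster[1]}; // both
}
```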

Suppose we have a `tt.dot` operation with block size [64, 128] = [64, 32] * [32, 128] for f16/bf16, and its input tensor layouts are defined as follows:
```
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#dpas, kWidth=2}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>

%d = tt.dot %a, %b, %c : tensor<64x32xf16, #dot_operand_a> * tensor<32x128xf16, #dot_operand_b> -> tensor<64x128xf32, #dpas>
```
The semantics of this `tt.dot` imply the following GEMM tiling configuration:

warp[:0] warp[:1] warp[:0] warp[:1]
|----^----|----^----|----^----|----^----|
repCluster[1]
<--------->
┌────┬────┬────┬────┬────┬────┬────┬────┐
│W0R0│W0R1│W1R0│W1R1│W0R4│W0R5│W1R4│W1R5│
│W2R0│W2R1│W3R0│W3R1│W2R4│W2R5│W3R4│W3R5│
warpPerCTA = [[W0, W1], ├────┼────┼────┼────┼────┼────┼────┼────┤
[W2, W3]] │W0R2│W0R3│W1R2│W1R3│W0R6│W0R7│W1R6│W1R7│
│W2R2│W2R3│W3R2│W3R3│W2R6│W2R7│W3R6│W3R7│
└────┴────┴────┴────┴────┴────┴────┴────┘


- ^ ┌────┬────┐ ┌────┬────┬────┬────┬────┬────┬────┬────┐
| | │W0R0│W0R2│ │W0R0│W0R1│W1R0│W1R1│W0R4│W0R5│W1R4│W1R5│
| | │W1R0│W1R2│ │ │ │ │ │ │ │ │ │
warp[0:] < repCluster[0] | ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
| | │W0R1│W0R3│ │W0R2│W0R3│W1R2│W1R3│W0R6│W0R7│W1R6│W1R7│
| | │W1R1│W1R3│ │ │ │ │ │ │ │ │ │
- v ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
| │W2R0│W2R2│ │W2R0│W2R1│W3R0│W3R1│W2R4│W2R5│W3R4│W3R5│
| │W3R0│W3R2│ │ │ │ │ │ │ │ │ │
warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
| │W2R1│W2R3│ │W2R2│W2R3│W3R2│W3R3│W2R6│W2R7│W3R6│W3R7│
| │W3R1│W3R3│ │ │ │ │ │ │ │ │ │
- ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
| │W0R4│W0R6│ │W0R8│W0R9│W1R8│W1R9│W0 │W0 │W1 │W1 │
| │W1R4│W1R6│ │ │ │ │ │R12 │R13 │R12 │R13 │
warp[0:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
| │W0R5│W0R7│ │W0 │W0 │W1 │W1 │W0 │W0 │W1 │W1 │
| │W1R5│W1R7│ │R10 │R11 │R10 │R11 │R14 │R15 │R14 │R15 │
- ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
| │W2R4│W2R6│ │W2R8│W2R9│W3R8│W3R9│W2 │W2 │W3 │W3 │
| │W3R4│W3R6│ │ │ │ │ │R12 │R13 │R12 │R13 │
warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
| │W2R5│W2R7│ │W2 │W2 │W3 │W3 │W2 │W2 │W3 │W3 │
| │W3R5│W3R7│ │R10 │R11 │R10 │R11 │R14 │R15 │R14 │R15 │
- └────┴────┘ └────┴────┴────┴────┴────┴────┴────┴────┘
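To sanity-check the figures above, here is a hedged arithmetic sketch of the per-warp repetition counts for this example; the derivation follows from the parameters stated in this description rather than from the dialect's code:

```
// [64, 128] = [64, 32] * [32, 128], opsPerChan = 2:
// instruction tiles are A 8x16, B 16x16, C/D 8x16.
constexpr unsigned M = 64, N = 128, K = 32;
constexpr unsigned warpsM = 2, warpsN = 2; // warpsPerCTA = [2, 2]
constexpr unsigned instM = 8, instN = 16, instK = 16;

constexpr unsigned repM = M / (warpsM * instM); // 4 repetitions along M
constexpr unsigned repN = N / (warpsN * instN); // 4 repetitions along N
constexpr unsigned repK = K / instK;            // 2 repetitions along K

static_assert(repM * repN == 16, "16 C/D tiles per warp: R0..R15 above");
static_assert(repM * repK == 8, "8 A tiles per warp: R0..R7 above");
```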


}];

let parameters = (
@@ -70,7 +186,7 @@ along the row (resp. col) dimension.
"unsigned":$opsPerChannel,
ArrayRefParameter<"unsigned">:$warpsPerCTA__,
ArrayRefParameter<"unsigned">:$repCluster,
"unsigned":$subGroupSize
"unsigned":$threadsPerWarp_
);

let extraClassDeclaration = extraDistributedDeclaration # [{
14 changes: 7 additions & 7 deletions third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -134,7 +134,7 @@ SmallVector<unsigned> DpasEncodingAttr::getShapeC() const {
SmallVector<unsigned> DpasEncodingAttr::getSizePerThread() const {
size_t rank = getWarpsPerCTA().size();
SmallVector<unsigned> res(rank, 1);
unsigned threadsPerWarp = getSubGroupSize();
unsigned threadsPerWarp = getThreadsPerWarp_();
auto shapeC = getDPASInstShapeC();
unsigned elemsNum = product<unsigned>(shapeC);
unsigned elemsPerThread = elemsNum / threadsPerWarp;
@@ -260,7 +260,7 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, int opIdx) const {
auto shapePerCTA = getShapePerCTA(*this, shape);
auto rep = getDPASRepetitions(shapePerCTA, opIdx);
auto threadsPerWar = getSubGroupSize();
auto threadsPerWar = getThreadsPerWarp_();
size_t rank = shape.size();
if (opIdx == 0) {
auto shapeA = getShapeA();
@@ -296,7 +296,7 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
size_t rank = getWarpsPerCTA().size();
SmallVector<unsigned> res(rank, 1);
auto executionSize = getExecutionSize();
auto subGroupSize = getSubGroupSize();
auto subGroupSize = getThreadsPerWarp_();
if (subGroupSize < executionSize) {
llvm::report_fatal_error("DpasEncodingAttr sub-group size could not be "
"smaller than the execution size");
@@ -340,7 +340,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout");
if (opIdx == 0) {
SmallVector<unsigned> shapeA = getDPASInstShapeA();
unsigned subGroupSize = getSubGroupSize();
unsigned subGroupSize = getThreadsPerWarp_();
unsigned opsPerChannel = getOpsPerChannel();

// pack the value to i16 for scalar bit width <=16.
@@ -359,7 +359,7 @@

if (opIdx == 1) {
auto shapeB = getShapeB();
auto subGroupSize = getSubGroupSize();
auto subGroupSize = getThreadsPerWarp_();
auto executionSize = getExecutionSize();
if (subGroupSize < executionSize) {
llvm::report_fatal_error("DpasEncodingAttr sub-group size could not "
@@ -394,7 +394,7 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
assert(rank == 2 || rank == 3);
SmallVector<unsigned> contigPerThread(rank, 1);

unsigned threadsPerWarp = getSubGroupSize();
unsigned threadsPerWarp = getThreadsPerWarp_();
auto instShapeC = getDPASInstShapeC();
// The software vectorization vectorizes the value as a C array: int a[N] ->
// int a[N][threadsPerWarp]
@@ -506,7 +506,7 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
<< "systolicDepth = " << getSystolicDepth() << ", "
<< "executionSize = " << getExecutionSize() << ", "
<< "opsPerChan = " << getOpsPerChannel() << ", "
<< "threadsPerWarp = " << getSubGroupSize() << ", "
<< "threadsPerWarp = " << getThreadsPerWarp_() << ", "
<< "warpsPerCTA = [" << llvm::ArrayRef<unsigned>(warpsPerCTA) << "], "
<< "repCluster = [" << repCluster << "], "
<< "A = [" << rA << "], "