Skip to content

Commit

Permalink
[DPAS] Pick warpsPerCTA based on fast changing axis of A matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbaden committed Dec 13, 2024
1 parent 37b841e commit 0ecff19
Showing 1 changed file with 19 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ namespace {
SmallVector<unsigned>
getWarpsPerTile(tt::DotOp dotOp,
ttg::intel::DpasEncodingAttr::DPASCapability dpasCap,
const ArrayRef<int64_t> shape, unsigned numWarps) {
const ArrayRef<int64_t> shape, unsigned numWarps,
const SmallVector<unsigned> &order) {

auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
Expand Down Expand Up @@ -64,7 +66,7 @@ getWarpsPerTile(tt::DotOp dotOp,
uint32_t colRowRatio =
ceil<uint32_t>(dpasCap.executionSize, dpasCap.repeatCount);

int rowDim = rank - 2, colDim = rank - 1;
int rowDim = order[rank - 2], colDim = order[rank - 1];
do {
if (ret[rowDim] * ret[colDim] >= numWarps)
break;
Expand All @@ -78,7 +80,6 @@ getWarpsPerTile(tt::DotOp dotOp,
ret[colDim] *= 2;
}
} while (true);

return ret;
}

Expand Down Expand Up @@ -117,8 +118,22 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
Type elemType = oldAType.getElementType();
unsigned opsPerChan =
ttg::intel::DpasEncodingAttr::getOpsPerChannel(elemType);

SmallVector<unsigned> order = {0, 1};
Operation *aOp = a.getDefiningOp();
if (aOp && isa<ttg::ConvertLayoutOp>(aOp)) {
auto valueToConvert = aOp->getOperand(0);
aOp = valueToConvert.getDefiningOp();
}
if (aOp && isa<tt::LoadOp>(aOp)) {
assert(aOp->getNumResults() == 1);
Attribute layout =
cast<RankedTensorType>(aOp->getResult(0).getType()).getEncoding();
order = triton::gpu::getOrder(layout);
}

SmallVector<unsigned> warpsPerTile =
getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
getWarpsPerTile(dotOp, dpasCap, retShape, numWarps, order);
size_t rank = retShape.size();
SmallVector<unsigned> repCluster(rank, 1);

Expand Down

0 comments on commit 0ecff19

Please sign in to comment.