Skip to content

Commit

Permalink
[DPAS] Pick warpsPerCTA based on fast changing axis of A matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbaden committed Dec 18, 2024
1 parent c4201fa commit 7eee7e3
Showing 1 changed file with 19 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ namespace {

SmallVector<unsigned>
getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
const ArrayRef<int64_t> shape, unsigned numWarps) {
const ArrayRef<int64_t> shape, unsigned numWarps,
const SmallVector<unsigned> &order) {
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
Expand Down Expand Up @@ -64,7 +65,7 @@ getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
uint32_t colRowRatio =
ceil<uint32_t>(dpasCap.executionSize, dpasCap.repeatCount);

int rowDim = rank - 2, colDim = rank - 1;
int rowDim = order[rank - 2], colDim = order[rank - 1];
do {
if (ret[rowDim] * ret[colDim] >= numWarps)
break;
Expand All @@ -78,7 +79,6 @@ getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
ret[colDim] *= 2;
}
} while (true);

return ret;
}

Expand Down Expand Up @@ -115,9 +115,24 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {

auto dpasCap = ttgi::DpasEncodingAttr::getDPASCapability(mod);
Type elemType = oldAType.getElementType();

unsigned opsPerChan = ttgi::DpasEncodingAttr::getOpsPerChannel(elemType);

SmallVector<unsigned> order = {0, 1};
Operation *aOp = a.getDefiningOp();
if (aOp && isa<ttg::ConvertLayoutOp>(aOp)) {
auto valueToConvert = aOp->getOperand(0);
aOp = valueToConvert.getDefiningOp();
}
if (aOp && isa<tt::LoadOp>(aOp)) {
assert(aOp->getNumResults() == 1);
Attribute layout =
cast<RankedTensorType>(aOp->getResult(0).getType()).getEncoding();
order = triton::gpu::getOrder(layout);
}

SmallVector<unsigned> warpsPerTile =
getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
getWarpsPerTile(dotOp, dpasCap, retShape, numWarps, order);
size_t rank = retShape.size();
SmallVector<unsigned> repCluster(rank, 1);

Expand Down

0 comments on commit 7eee7e3

Please sign in to comment.