Remove Intel code in common Triton GPU Dialect source code.
chengjunlu committed Dec 17, 2024
1 parent f1a893a commit ede71d8
Showing 4 changed files with 71 additions and 62 deletions.
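At a glance, the refactoring replaces backend-specific branches in the common TritonGPU dialect (such as the intel::DpasEncodingAttr checks deleted below) with a verification hook on the layout dialect interface, so each backend validates its own dot-operand encodings. The stand-alone C++ sketch below illustrates only that dispatch pattern; it is not MLIR code, and every name in it (BackendLayoutInterface, DpasLikeBackend, opsPerChannel) is an illustrative stand-in for DialectInferLayoutInterface::verifyDotOpEncoding and the Intel DPAS kWidth check.

#include <iostream>

// Stand-in for DialectInferLayoutInterface: the common code only knows this
// interface, never a concrete backend encoding.
struct BackendLayoutInterface {
  virtual ~BackendLayoutInterface() = default;
  // Mirrors verifyDotOpEncoding(opIdx, parent, kWidth); returns false on error.
  virtual bool verifyDotOpEncoding(unsigned opIdx, unsigned kWidth) const = 0;
};

// Stand-in for the Intel implementation: kWidth must match opsPerChannel,
// analogous to the DpasEncodingAttr check moved into the Intel dialect below.
struct DpasLikeBackend : BackendLayoutInterface {
  unsigned opsPerChannel;
  explicit DpasLikeBackend(unsigned ops) : opsPerChannel(ops) {}
  bool verifyDotOpEncoding(unsigned /*opIdx*/, unsigned kWidth) const override {
    return kWidth == opsPerChannel;
  }
};

// Stand-in for DotOperandEncodingAttr::verify: dispatch through the interface
// instead of hard-coding every backend.
bool verifyDotOp(const BackendLayoutInterface &backend, unsigned opIdx,
                 unsigned kWidth) {
  if (!backend.verifyDotOpEncoding(opIdx, kWidth)) {
    std::cerr << "ttg.dot_op is invalid with this parent layout\n";
    return false;
  }
  return true;
}

int main() {
  DpasLikeBackend dpas(/*opsPerChannel=*/2);
  std::cout << verifyDotOp(dpas, /*opIdx=*/0, /*kWidth=*/2) << "\n"; // 1 (ok)
  std::cout << verifyDotOp(dpas, /*opIdx=*/0, /*kWidth=*/4) << "\n"; // 0 (rejected)
}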
5 changes: 5 additions & 0 deletions include/triton/Dialect/Triton/IR/Dialect.h
@@ -76,6 +76,11 @@ class DialectInferLayoutInterface
virtual LogicalResult
verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA,
Attribute operandEncodingB) const = 0;

// Verify that the dotOp layout encoding is legal when it uses a third-party
// Triton GPU dialect attribute as its parent.
virtual LogicalResult verifyDotOpEncoding(unsigned opIdx, Attribute parent,
unsigned kWidth) const = 0;
};

class DialectVerifyTensorLayoutInterface
107 changes: 46 additions & 61 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -304,16 +304,6 @@ SmallVector<unsigned> getOrder(Attribute layout) {
}
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
auto rank = dotLayout.getWarpsPerCTA().size();
// FIXME: delete if branch for `DpasEncodingAttr` and provide more
// general solution to make `getOrderForDotOperand` function compatible
// with Intel layouts.
// More details:
// https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) {
SmallVector<unsigned> order(rank);
std::iota(order.rbegin(), order.rend(), 0);
return order;
}
return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
}
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -1093,10 +1083,6 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
return amdWmmaParent.getTotalElemsPerThreadForOperand(
shape, eltTy, getKWidth(), getOpIdx());
}
if (auto dpasParent = mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) {
return dpasParent.getTotalElemsPerThreadForOperand(
shape, eltTy, getKWidth(), getOpIdx());
}
}
if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
auto shapePerCTA = getShapePerCTA(*this, shape);
@@ -1161,17 +1147,8 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
return {};
}
SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
// FIXME: delete if branch for `DpasEncodingAttr` and provide more
// general solution to make `getOrderForDotOperand` function compatible
// with Intel layouts.
// More details:
// https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
return ::getOrder(*this);
} else {
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
/*kMajor*/ true);
}
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
/*kMajor*/ true);
}

LogicalResult DotOperandEncodingAttr::verify(
@@ -1184,42 +1161,6 @@ LogicalResult DotOperandEncodingAttr::verify(
if (!parent) {
return emitError() << "ttg.dot_op parent paramenter cannot be null";
}
if (auto parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper()))
return emitError() << "ttg.dot_op kWidth parameter can only be "
"non-zero for Ampere or Hopper MMA parent";
if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper()))
return emitError() << "ttg.dot_op kWidth parameter is mandatory for "
"Ampere or Hopper MMA parent";
if (opIdx != 0 && parentAttr.isHopper())
return emitError()
<< "ttg.dot_op opIdx parameter must be 0 for "
"Hopper MMA parent, since Hopper WGMMA only allows first "
"operand to be in registers";
return success();
}

if (auto parentAttr = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
if (kWidth != 16 && parentAttr.getVersion() == 1 ||
kWidth != 8 && parentAttr.getVersion() == 2)
return emitError() << "ttg.dot_op kWidth parameter must be 16 for "
"gfx11 and 8 for gfx12";
return success();
}

if (auto parentAttr = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
if (kWidth == 0)
return emitError() << "ttg.dot_op kWidth parameter is mandatory for "
"MFMA parent";
return success();
}

if (auto parentAttr = mlir::dyn_cast<intel::DpasEncodingAttr>(parent)) {
if (kWidth != parentAttr.getOpsPerChannel())
return emitError() << "ttg.dot_op kWidth parameter must match the "
"parent's opsPerChannel";
return success();
}

if (auto parentAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(parent)) {
if (kWidth != 0)
@@ -1228,6 +1169,14 @@ LogicalResult DotOperandEncodingAttr::verify(
return success();
}

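// Backend-specific verification now goes through the dialect that owns the
// parent attribute: the NVIDIA/AMD checks that used to be inlined here move
// into TritonGPUInferLayoutInterface further down, and the Intel DPAS check
// moves into the Intel dialect's interface implementation.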
if (auto parentAttr = mlir::dyn_cast<MmaEncodingTrait>(parent)) {
Dialect &dialect = parentAttr.getDialect();
auto interface = mlir::cast<DialectInferLayoutInterface>(&dialect);
if (interface->verifyDotOpEncoding(opIdx, parent, kWidth).failed())
return emitError() << "ttg.dot_op is invalid with parent layout: "
<< parent;
}

if (auto parentAttr = mlir::dyn_cast<BlockedEncodingAttr>(parent)) {
if (kWidth != 0)
return emitError() << "ttg.dot_op kWidth parameter is not supported "
@@ -2678,6 +2627,42 @@ struct TritonGPUInferLayoutInterface
return success();
}

LogicalResult verifyDotOpEncoding(unsigned opIdx, Attribute parent,
unsigned kWidth) const override {
// The interface receives no location or diagnostic callback, so report
// errors at an unknown location; the caller in DotOperandEncodingAttr::verify
// wraps the failure in its own diagnostic as well.
auto emitError = [&]() {
  return mlir::emitError(mlir::UnknownLoc::get(parent.getContext()));
};
if (auto parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper()))
return emitError() << "ttg.dot_op kWidth parameter can only be "
"non-zero for Ampere or Hopper MMA parent";
if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper()))
return emitError() << "ttg.dot_op kWidth parameter is mandatory for "
"Ampere or Hopper MMA parent";
if (opIdx != 0 && parentAttr.isHopper())
return emitError()
<< "ttg.dot_op opIdx parameter must be 0 for "
"Hopper MMA parent, since Hopper WGMMA only allows first "
"operand to be in registers";
return success();
}

if (auto parentAttr = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
if (kWidth != 16 && parentAttr.getVersion() == 1 ||
kWidth != 8 && parentAttr.getVersion() == 2)
return emitError() << "ttg.dot_op kWidth parameter must be 16 for "
"gfx11 and 8 for gfx12";
return success();
}

if (auto parentAttr = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
if (kWidth == 0)
return emitError() << "ttg.dot_op kWidth parameter is mandatory for "
"MFMA parent";
return success();
}

return emitError() << "ttg.dot_op unknown parent layout: " << parent;
}

// Given a src shape + encoding and a dst shape, our goal is to compute a dst
// encoding that makes the reshape a "nop". That is, if GPU thread [x,y,z]
// contains elements [a,b,c,d] before the reshape, it contains those same
14 changes: 14 additions & 0 deletions third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -639,6 +639,20 @@ struct TritonIntelGPUInferLayoutInterface
return success();
}

LogicalResult verifyDotOpEncoding(unsigned opIdx, Attribute parent,
unsigned kWidth) const override {
// As in the common TritonGPU implementation, report errors at an unknown
// location since the interface provides no location or diagnostic callback.
auto emitError = [&]() {
  return mlir::emitError(mlir::UnknownLoc::get(parent.getContext()));
};

if (auto parentAttr = mlir::dyn_cast<intel::DpasEncodingAttr>(parent)) {
if (kWidth != parentAttr.getOpsPerChannel())
return emitError() << "ttg.dot_op kWidth parameter must match the "
"parent's opsPerChannel";
return success();
}

return emitError() << "ttg.dot_op unknown parent layout: " << parent;
}

LogicalResult
inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
@@ -515,7 +515,12 @@ struct LoadOpConversion
const bool memoryRowMajor = (memoryLayoutInfo == "row_major");

DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value();
auto dotOrder = dotLayout.getThreadOrder();
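// With the DpasEncodingAttr special case removed from
// DotOperandEncodingAttr::getThreadOrder above, the Intel load lowering now
// derives the thread order from the layout's LinearLayout form rather than
// calling getThreadOrder on the dot layout directly.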
std::optional<LinearLayout> dotLL =
dotLayout.toLinearLayout(tensorType.getShape());
assert(dotLL.has_value() && "invalid dot layout to linear layout");
LinearEncodingAttr dotLLAttr =
LinearEncodingAttr::get(rewriter.getContext(), *dotLL);
SmallVector<unsigned> dotOrder = dotLLAttr.getThreadOrder();
size_t rank = dotOrder.size();
const bool valueRowMajor =
(dotOrder[rank - 2] == 1 && dotOrder[rank - 1] == 0);