Skip to content

Commit

Permalink
Add a new interface for 3rd party Triton GPU dialect to verify the Do…
Browse files Browse the repository at this point in the history
…tOp layout.
  • Loading branch information
chengjunlu committed Dec 18, 2024
1 parent a7dcf4c commit cbf3c59
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 7 deletions.
5 changes: 5 additions & 0 deletions include/triton/Dialect/Triton/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ class DialectInferLayoutInterface
virtual LogicalResult
verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA,
Attribute operandEncodingB) const = 0;

// Verify the dotOp layout encoding is legal if it uses 3rd party Triton GPU
// dialect attribute as parent.
virtual LogicalResult verifyDotOpEncoding(unsigned opIdx, Attribute parent,
unsigned kWidth) const = 0;
};

class DialectVerifyTensorLayoutInterface
Expand Down
16 changes: 9 additions & 7 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1189,20 +1189,22 @@ LogicalResult DotOperandEncodingAttr::verify(
return success();
}

if (auto parentAttr = mlir::dyn_cast<intel::DpasEncodingAttr>(parent)) {
if (kWidth != parentAttr.getOpsPerChannel())
return emitError() << "ttg.dot_op kWidth parameter must match the "
"parent's opsPerChannel";
return success();
}

if (auto parentAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(parent)) {
if (kWidth != 0)
return emitError() << "ttg.dot_op kWidth parameter is not supported "
"when the parent is a warp layout";
return success();
}

if (auto parentAttr = mlir::dyn_cast<MmaEncodingTrait>(parent)) {
Dialect &dialect = parentAttr.getDialect();
auto interface = mlir::cast<DialectInferLayoutInterface>(&dialect);
if (interface->verifyDotOpEncoding(opIdx, parent, kWidth).failed())
return emitError() << "ttg.dot_op is invalid with parent layout: "
<< parent;
return success();
}

if (auto parentAttr = mlir::dyn_cast<BlockedEncodingAttr>(parent)) {
if (kWidth != 0)
return emitError() << "ttg.dot_op kWidth parameter is not supported "
Expand Down
14 changes: 14 additions & 0 deletions third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,20 @@ struct TritonIntelGPUInferLayoutInterface
return success();
}

LogicalResult verifyDotOpEncoding(unsigned opIdx, Attribute parent,
unsigned kWidth) const override {
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError;

if (auto parentAttr = mlir::dyn_cast<intel::DpasEncodingAttr>(parent)) {
if (kWidth != parentAttr.getOpsPerChannel())
return emitError() << "ttg.dot_op kWidth parameter must match the "
"parent's opsPerChannel";
return success();
}

return emitError() << "ttg.dot_op unknown parent layout: " << parent;
}

LogicalResult
inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
Expand Down

0 comments on commit cbf3c59

Please sign in to comment.