--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -5,6 +5,9 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
+
+#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
+
 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
@@ -208,12 +211,15 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
     auto sizePerThread = distributedLayout.getSizePerThread();
     auto threadsPerWarp = distributedLayout.getThreadsPerWarp();
-    // ThreadsPerWarp does not align with this function for slice layout
+    auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
+    // ThreadsPerWarp and warpsPerCTA do not align with this function for
+    // slice layout
     if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
       threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent());
       threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
+      warpsPerCTA = getWarpsPerCTA(sliceLayout.getParent());
+      warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
     }
-    auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
     assert(sizePerThread.size() == threadsPerWarp.size() &&
            sizePerThread.size() == warpsPerCTA.size());
     SmallVector<unsigned> shape;
@@ -305,6 +311,16 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
     auto rank = dotLayout.getWarpsPerCTA().size();
+    // FIXME: delete the if branch for `DpasEncodingAttr` and provide a more
+    // general solution to make the `getOrderForDotOperand` function
+    // compatible with Intel layouts.
+    // More details:
+    // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
+    if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) {
+      SmallVector<unsigned> order(rank);
+      std::iota(order.rbegin(), order.rend(), 0);
+      return order;
+    }
     return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -1165,8 +1188,17 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return {};
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
-  return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
-                               /*kMajor*/ true);
+  // FIXME: delete the if branch for `DpasEncodingAttr` and provide a more
+  // general solution to make the `getOrderForDotOperand` function
+  // compatible with Intel layouts.
+  // More details:
+  // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
+  if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
+    return ::getOrder(*this);
+  } else {
+    return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
+                                 /*kMajor*/ true);
+  }
 }
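For readers tracing the patch, here is a minimal standalone C++ sketch, independent of the Triton codebase, of the two index manipulations the diff relies on: building a row-major dimension order with std::iota over reverse iterators (the DpasEncodingAttr branch), and erasing the sliced dimension from the parent layout's per-CTA vector (the SliceEncodingAttr branch). The rank, warpsPerCTA values, and sliced dimension below are hypothetical.

#include <cassert>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Row-major order for rank 3: std::iota over reverse iterators fills the
  // vector from the back, yielding {2, 1, 0}, i.e. the last (fastest-varying)
  // dimension comes first, as in the DpasEncodingAttr branch of getOrder.
  unsigned rank = 3;
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0u);
  assert((order == std::vector<unsigned>{2, 1, 0}));

  // Slice-layout handling: take the parent's warpsPerCTA (values here are
  // made up) and erase the sliced dimension, mirroring
  // warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim()).
  std::vector<unsigned> warpsPerCTA = {4, 2, 1};
  unsigned slicedDim = 1;
  warpsPerCTA.erase(warpsPerCTA.begin() + slicedDim);
  assert((warpsPerCTA == std::vector<unsigned>{4, 1}));

  std::printf("order = {%u, %u, %u}, warpsPerCTA = {%u, %u}\n",
              order[0], order[1], order[2], warpsPerCTA[0], warpsPerCTA[1]);
}

The iota-over-rbegin trick is why the DPAS branch always returns a plain row-major order regardless of rank, which is exactly the behavior ::getOrder(*this) falls back to in getThreadOrder.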